#!/usr/bin/env python3
# src/1on1.py -- added in shutbox commit 3dd7a576f76815f767c9dcad51476c26cbed4291
# ("added 1on1", Alexander Schmidt, Tue, 3 Nov 2020 10:39:54 +0100; new file,
# mode 100755).
# http://git.treefish.org/~alex/shutbox.git/commitdiff_plain/3dd7a576f76815f767c9dcad51476c26cbed4291
# NOTE(review): recovered from a whitespace-collapsed patch; the indentation
# (notably of `num_turn += 1` and the per-game print) was reconstructed --
# verify against upstream.
"""Tabular Q-learning agent for a two-player ("1on1") shutbox game.

The agent plays against an opponent that picks uniformly random options.
States and options are packed into integer row/column indices of a NumPy
Q-table by :func:`find_state_qid` and :func:`find_option_qid`.
"""

import numpy as np
import random
import sys

learning_rate = 0.1
discount_factor = 1.0

# Q-table dimensions. They assume 10 rods numbered 1..10 and two six-sided
# dice, matching the encodings in find_state_qid() and find_option_qid().
states_dim = 36864  # 2^10 * 6^2
actions_dim = 539  # (10+1) * (6+1)^2
num_episodes = 10000000000


def find_state_qid(shutable, diced):
    """Encode a game state as a Q-table row index.

    The low 10 bits hold one bit per still-shutable rod (rod ``r`` sets bit
    ``r - 1``). Above them, the dice faces form a base-6 number
    (face 1 -> digit 0), scaled by ``2**10``.

    Args:
        shutable: iterable of rod numbers (assumed 1..10) that can still be shut.
        diced: sequence of die faces (assumed 1..6, two dice).

    Returns:
        int index in ``range(states_dim)`` under the assumptions above.
    """
    qid = sum(2 ** (rod - 1) for rod in shutable)
    for i, face in enumerate(diced):
        qid += (face - 1) * 6 ** i * 2 ** 10
    return qid


def find_option_qid(option):
    """Encode an option (a sequence of small ints) as a Q-table column index.

    The elements are read as a base-7 number (first element least
    significant) and the result is scaled by ``11 ** (len(option) - 1)``.
    The scaling separates the id ranges by option length:

    * single-element options map to the element itself (below 11);
    * two-element options map to multiples of 11, at most 11 * 48 = 528.

    Both therefore fit in ``actions_dim`` = 539.

    NOTE(review): this assumes Game.get_options() yields 1- or 2-element
    options with values <= 10 and <= 6 respectively -- confirm.

    Returns:
        int column index; 0 for an empty option.
    """
    qid = 0
    # Hoisted out of the loop; it is only used when option is non-empty,
    # in which case it is an int.
    scale = 11 ** (len(option) - 1)
    for i, value in enumerate(option):
        qid += value * 7 ** i * scale
    return qid


def select_option(opts, qs):
    """Pick one option at random, weighted by its current Q-value.

    Args:
        opts: non-empty sequence of options (see find_option_qid).
        qs: one Q-table row; ``qs[find_option_qid(opt)]`` is opt's weight.

    Returns:
        ``(option, option_qid)`` for the chosen option. When every weight is
        zero, the first option is returned.

    Raises:
        ValueError: if ``opts`` is empty.
    """
    if len(opts) == 0:
        raise ValueError("select_option() needs at least one option")
    candidates = [(opt, find_option_qid(opt)) for opt in opts]
    total = sum((qs[opt_qid] for _, opt_qid in candidates), 0.0)
    ran_pt = random.uniform(0.0, total)
    decision_pt = 0.0
    for opt, opt_qid in candidates:
        # The running sum is accumulated in the same order as `total`, so it
        # reaches `total` exactly on the last candidate.
        decision_pt += qs[opt_qid]
        if ran_pt <= decision_pt:
            return opt, opt_qid
    # Only reachable with negative weights, which the training loop never
    # produces; fall back to the last option instead of returning None.
    return candidates[-1]


def main():
    """Train the Q-table against a uniformly random opponent.

    Plays ``num_episodes`` games, choosing the starting player at random.
    On the agent's (even) turns it picks an option with select_option() and
    applies a one-step Q-learning update. On the opponent's (odd) turns a
    uniformly random option is shut. After every game the running ratio
    ``stats[0] / stats[1]`` is printed.
    """
    # Imported here rather than at module level so the pure helpers above
    # can be imported (e.g. by tests) without the project-local game module.
    from game import Game

    # One row per encoded state, one column per encoded option. Every entry
    # starts at 1.0, so select_option() initially weights options uniformly.
    Q = np.ones([states_dim, actions_dim])

    # Counters behind the printed ratio. They start at 1.0 so the ratio is
    # defined from the first game on.
    # NOTE(review): presumably stats[0] counts agent wins and stats[1] agent
    # losses -- confirm against the Game rules.
    stats = [1.0, 1.0]

    for _ in range(num_episodes):
        g = Game()
        g.dice()
        state_qid = find_state_qid(g.get_shutable(), g.get_diced())
        # Even turns belong to the learning agent, odd turns to the random
        # opponent. The starting parity is chosen at random.
        num_turn = random.randint(0, 1)
        while not g.is_over():
            options = g.get_options()
            # len() rather than truthiness: the container type returned by
            # Game.get_options() is not visible here.
            if len(options) > 0:
                if num_turn % 2 == 0:
                    opt, opt_qid = select_option(options, Q[state_qid, :])
                    g.shut(opt)
                    g.dice()
                    # Moves are never rewarded directly. The only outcome
                    # signal is the row reset at the end of the game below.
                    reward = 0.0
                    new_state_qid = find_state_qid(
                        g.get_shutable(), g.get_diced())
                    # One-step Q-learning update.
                    # NOTE(review): the max runs over all actions_dim columns,
                    # including options never legal in new_state_qid (they
                    # keep their initial 1.0), which biases the target upward;
                    # consider restricting it to the legal options -- confirm.
                    Q[state_qid, opt_qid] += learning_rate * (
                        reward
                        + discount_factor * np.max(Q[new_state_qid, :])
                        - Q[state_qid, opt_qid]
                    )
                    state_qid = new_state_qid
                else:
                    choice = random.randint(0, len(options) - 1)
                    g.shut(options[choice])
                    g.dice()
                    state_qid = find_state_qid(
                        g.get_shutable(), g.get_diced())
            num_turn += 1
        if num_turn % 2 == 0:
            # The game ended with the agent to move. Count it in stats[1] and
            # zero the final state's row: the bootstrap max of that row
            # becomes 0, which pulls down the Q-values of moves leading here.
            stats[1] += 1.0
            Q[state_qid, :] = 0.0
        else:
            stats[0] += 1.0
        print(stats[0] / stats[1])


if __name__ == "__main__":
    main()