#!/usr/bin/env python3

import numpy as np
import random
import sys

from game import Game

# Q-learning hyper-parameters.
learning_rate = 0.1      # step size of each temporal-difference update
discount_factor = 1.0    # future value is not discounted

# Q-table dimensions.
# States: 10 rods, each bit set or clear (2 ** 10), times two dice (6 ** 2).
states_dim = 2 ** 10 * 6 ** 2
# Actions: one digit with 11 values (0..10) times two digits with 7 values
# (0..6) each, i.e. 11 * 7 * 7 columns.
actions_dim = (10 + 1) * (6 + 1) ** 2
# Number of training episodes to play (10 billion).
num_episodes = 10 ** 10

def find_state_qid(shutable, diced):
    """Encode a board position as a row index into the Q table.

    The low 10 bits form a bitmask of the rods listed in *shutable* (rod
    ``r`` sets bit ``r - 1``).  Above those bits the dice faces are packed as
    base-6 digits (face ``f`` contributes digit ``f - 1``), with the first
    die least significant.

    Args:
        shutable: iterable of rod numbers (presumably 1..10, matching the
            2 ** 10 factor in ``states_dim``).
        diced: sequence of die faces (presumably 1..6).

    Returns:
        The integer row index; for rods 1..10 and two dice it lies in
        ``range(36864)``.
    """
    rod_mask = sum(2 ** (rod - 1) for rod in shutable)
    # The dice digits sit above the 10 rod bits, hence the 2 ** 10 factor.
    # Passing rod_mask as sum's start value keeps the original addition order.
    return sum(
        ((face - 1) * 6 ** position * 2 ** 10
         for position, face in enumerate(diced)),
        rod_mask,
    )

def find_option_qid(option):
    """Encode a shutting option as a column index into the Q table.

    The action space has ``actions_dim = 539 = 11 * 7 * 7`` columns, read as a
    mixed-radix number:
      - element 0 of *option* is a digit in 0..10 with weight 1;
      - element 1 is a digit in 0..6 with weight 11;
      - element 2 is a digit in 0..6 with weight 77.

    Missing trailing elements count as 0.  As a result, every valid option
    maps to a distinct index in ``range(539)``.

    Args:
        option: sequence of at most three non-negative ints, bounded as above.
            A 0 digit cannot be told apart from an absent one, so elements are
            presumably positive (e.g. rod numbers).  Confirm against
            ``Game.get_options``.

    Returns:
        The column index (0 for an empty option).

    Raises:
        ValueError: if *option* has more than three elements, or an element
            lies outside its digit's range.  Such an option has no column in
            the Q table.
    """
    # Radix of each digit position; their product is actions_dim (539).
    radices = (11, 7, 7)
    if len(option) > len(radices):
        raise ValueError(
            "option {!r} has {} elements; at most {} can be encoded".format(
                option, len(option), len(radices)))
    qid = 0
    # Each position needs its own weight (1, 11, 77).  Multiplying every digit
    # by one shared factor lets distinct options collide and pushes
    # 3-element options past the last Q-table column.
    weight = 1
    for position, (value, radix) in enumerate(zip(option, radices)):
        if not 0 <= value < radix:
            raise ValueError(
                "option element {!r} at position {} is outside 0..{}".format(
                    value, position, radix - 1))
        qid += value * weight
        weight *= radix
    return qid

def select_option(opts, qs):
    """Sample one option with probability proportional to its Q-value.

    This is roulette-wheel selection: draw a point uniformly in ``[0, total]``,
    where ``total`` is the sum of the options' Q-values, then walk the
    cumulative sum until it reaches that point.

    Args:
        opts: iterable of options understood by ``find_option_qid``.
        qs: Q-values indexed by option id (the caller passes one row of Q).
            Values are assumed non-negative.  If they are all zero, the first
            option is always chosen.

    Returns:
        ``(option, option_qid)`` for the chosen option.  Returns ``None`` when
        *opts* is empty, or when the threshold is never reached (e.g. with
        NaN values).
    """
    scored = [(option, find_option_qid(option)) for option in opts]
    # Passing 0.0 as the start value reproduces a float running sum in option order.
    total = sum((qs[qid] for _, qid in scored), 0.0)
    threshold = random.uniform(0.0, total)
    running = 0.0
    for option, qid in scored:
        running += qs[qid]
        if threshold <= running:
            return (option, qid)
    return None

# Q table: one row per encoded state, one column per encoded action
# (36864 x 539 float64 entries, roughly 159 MB).  Every entry starts at 1.0,
# so select_option's value-proportional sampling starts out uniform over the
# legal options.
Q = np.ones((states_dim, actions_dim))

running_score = [0.0] * 2  # not referenced anywhere else in this script
# Episode outcome counters: stats[0] counts games that ended on the
# opponent's turn, and stats[1] games that ended on the agent's turn (the
# training loop treats the latter as losses).  Both start at 1.0, so the
# ratio printed after each episode never divides by zero.
stats = [1.0] * 2

# Training loop: the learning agent plays against an opponent that picks
# uniformly among its legal options.  Both players move on the same Game
# instance and alternate turns; the parity of num_turn decides who moves
# (even = agent, odd = random opponent).
# NOTE(review): assumes Game.is_over() is True whenever get_options() is
# empty.  Otherwise the while-loop below spins forever without changing g.
for i in range(num_episodes):
    g = Game()
    g.dice()
    state_qid = find_state_qid(g.get_shutable(), g.get_diced())
    # Randomly decide who moves first (0 -> agent, 1 -> opponent).
    num_turn = random.randint(0, 1)
    while not g.is_over():
        options = g.get_options()
        if len(options) > 0:
            if num_turn % 2 == 0:
                # Agent's turn: sample an option with probability
                # proportional to its current Q-value in this state's row.
                opt, opt_qid = select_option( options, Q[state_qid, :] )
                g.shut(opt)
                g.dice()
                # No per-move reward.  Q-values start at 1.0 and only move
                # because some rows are zeroed after lost episodes (below).
                # With reward 0, learning_rate 0.1 and discount 1.0, every
                # update is a convex combination, so Q stays within [0, 1].
                reward = 0.0
                new_state_qid = find_state_qid(g.get_shutable(), g.get_diced())
                # One-step Q-learning update:
                #   Q(s, a) += lr * (r + gamma * max_a' Q(s', a') - Q(s, a))
                # NOTE(review): s' is the position the *opponent* moves from
                # next, not the agent's next decision point, and both players'
                # perspectives share one Q row per state id.  A state zeroed
                # after the agent gets stuck in it may also be reached here as
                # a position where the opponent is stuck.  If so, moves that
                # end the game in the agent's favour can be learned as
                # worthless.  Verify that this is the intended learning signal.
                Q[state_qid, opt_qid] += \
                    learning_rate * (reward
                                     + discount_factor * np.max(Q[new_state_qid, :])
                                     - Q[state_qid, opt_qid])
                state_qid = new_state_qid
            else:
                # Opponent's turn: a uniformly random legal option.
                choice = random.randint(0, len(options) - 1)
                g.shut(options[choice])
                g.dice()
                state_qid = find_state_qid(g.get_shutable(), g.get_diced())
            num_turn += 1
    # Game over.  If num_turn is even, the agent was due to move: this is
    # counted in stats[1] (a loss) and that state's Q row is zeroed, so
    # later updates that bootstrap from it receive no value.  Otherwise the
    # game ended on the opponent's turn and is counted in stats[0].
    if num_turn % 2 == 0:
        stats[1] += 1.0
        Q[state_qid, :] = 0.0
    else:
        stats[0] += 1.0
    # Running stats[0] / stats[1] ratio, printed once per episode.
    print(stats[0]/stats[1])
