import chainer
import chainer.functions as F
import chainer.links as L
import chainerrl
from env_line_trace import EnvLineTrace
import numpy as np
import argparse

"""
以下のような2層の全結合ニューラルネットワークを実装
  入力：状態
  出力：最適な行動
"""
class QFunction(chainer.Chain):
    def __init__(self, obs_size, n_actions, n_hidden_channels=50):
        super().__init__()
        with self.init_scope():
            self.l0 = L.Linear(obs_size, n_hidden_channels)
            self.l1 = L.Linear(n_hidden_channels, n_hidden_channels)
            self.l2 = L.Linear(n_hidden_channels, n_actions)

    def __call__(self, x, test=False):
        h = F.tanh(self.l0(x))
        h = F.tanh(self.l1(h))
        return chainerrl.action_value.DiscreteActionValue(self.l2(h))


# エージェントを定義
def get_agent(env, obs_size, n_actions):
    # 上記のニューラルネットワークを利用
    q_func = QFunction(obs_size, n_actions)
    optimizer = chainer.optimizers.Adam(eps=1e-2)
    optimizer.setup(q_func)

    # DoubleDQNに必要なパラメータを定義
    gamma = 0.95
    explorer = chainerrl.explorers.ConstantEpsilonGreedy(
        epsilon=1, random_action_func=env.action_space.sample)
    replay_buffer = chainerrl.replay_buffer.ReplayBuffer(capacity=10 ** 6)
    phi = lambda x: x.astype(np.float32, copy=False)

    # ChainerRLに用意されたDoubleDQNアルゴリズムを利用
    agent = chainerrl.agents.DoubleDQN(
        q_func, optimizer, replay_buffer, gamma, explorer,
        replay_start_size=500, update_interval=1,
        target_update_interval=100, phi=phi)
    return agent


# 学習を定義
def train(save_dir, n_episodes=100, max_episode_len=200):
    # 自分で用意した環境を利用
    env = EnvLineTrace(5, 7, 'target.png')
    obs_size = env.observation_space.shape[0]
    n_actions = env.action_space.n

    agent = get_agent(env, obs_size, n_actions)

    # n_episodesの数だけ試行を繰り返す
    for i in range(1, n_episodes + 1):
        # シミュレーターを初期位置に戻す
        obs = env.reset()
        reward = 0
        done = False
        R = 0
        t = 0

        # １試行の中で行動を繰り返す
        while not done and t < max_episode_len:
            env.render()
            action = agent.act_and_train(obs, reward)
            obs, reward, done, _ = env.step(action)
            R += reward
            t += 1
        if i % 20 == 0:
            print('episode:', i,
                  'R:', R,
                  'statistics:', agent.get_statistics())
        agent.stop_episode_and_train(obs, reward, done)

    # 学習したモデルを保存する
    agent.save(save_dir)


# 学習済みのモデルで動作確認
def test_episode(save_dir, test_steps=500):
    # 自分で用意した環境を利用
    env = EnvLineTrace()
    obs_size = env.observation_space.shape[0]
    n_actions = env.action_space.n

    agent = get_agent(env, obs_size, n_actions)

    agent.load(save_dir)

    for i in range(10):
        obs = env.reset()
        done = False
        R = 0
        t = 0
        while not done and t < test_steps:
            env.render()
            action = agent.act(obs)
            obs, r, done, _ = env.step(action)
            R += r
            t += 1
        agent.stop_episode()


def main():
    # コマンドライン引数を定義
    parser = argparse.ArgumentParser(description='Learner and tester for agent of tictactoe game.')
    parser.add_argument('--test', const=True, nargs='?', help='学習済みのエージェントをテストします')
    parser.add_argument('-d', '--save_dir', type=str, default='line_trace_agent', nargs='?', help='エージェントの保存／読み込み先ディレクトリ')
    args = parser.parse_args()

    if args.test:
        test_episode(args.save_dir)
    else:
        train(args.save_dir)


if __name__ == '__main__':
    main()
