import gym
import gym.spaces
import cv2
import numpy as np


class EnvLineTrace(gym.core.Env):
    '''
    observation_spaceを変更することでセンサーの個数を変更できる
    action_spaceを変更することで進行可能な向きの個数を変更できる
    target_img_pathを変更することでトレース対象の画像を変更できる
    '''
    def __init__(self, observation_space=5, action_space=7, target_img_path='target.png'):
        self.obs_size = observation_space
        self.action_space = gym.spaces.Discrete(action_space)
        self.observation_space = gym.spaces.Box(low=-0.0, high=1.0, shape=(observation_space,))
        self.direction_list = [(x * np.pi / (action_space - 1)) - (np.pi / 2) for x in range(action_space)]

        # Tracerの中心位置
        self.pos_x = 0
        self.pos_y = 0

        # Tracerの向き (0～2πで表現)
        self.direction = 0

        # トレース対象の読み込み
        # トレース対象は白黒の画像とする
        self.target_img = cv2.imread(target_img_path, cv2.IMREAD_GRAYSCALE)
        self.target_img_height, self.target_img_width = self.target_img.shape[:2]

        # Tracerの大きさを定義
        self.radius = 20
        # 1stepあたりの移動距離を定義
        self.distance = 5

        self.step_count = 0
        self.max_episode_len = 10000

        # センサーの位置を取得する
        self.pos_sensor_list = self.get_sensor_pos()

    '''
    エージェントからactionを受け取り、その行動後の状態や報酬を返す。
    1. エージェントから受け取ったactionに従って、Tracerを移動させる
    2. 移動先でセンサー情報を取得する
    3. センサー情報に基づいて報酬の計算を行う
    4. 試行を終わらせるかどうかを判断する
    5. 状態、報酬、試行終了の判断結果　をエージェントに返す
    '''
    def step(self, action):
        done = False

        # actionに従って移動する
        self.direction = self.direction + self.direction_list[action]
        self.pos_x = self.pos_x + self.distance * np.cos(self.direction)
        self.pos_y = self.pos_y + self.distance * np.sin(self.direction)

        # 移動先でセンサー情報を取得する
        self.pos_sensor_list = self.get_sensor_pos()
        state = np.array([1.0 if np.sum(self.target_img[int(x), int(y)]) == 0 else 0.0 for (y, x) in self.pos_sensor_list])

        # 報酬を計算する
        # 黒に反応したセンサーの個数が多いほど点数が増え、最大1を与える
        # 黒に反応したセンサーが無い場合は-1を与える
        reward = np.mean(state) if np.sum(state) != 0 else -1

        # Tracerが場外に出たら試行を終了する
        # 報酬は-10を与える
        if self.pos_x < self.radius or self.pos_x > self.target_img_width - self.radius or self.pos_y < self.radius \
                or self.pos_y > self.target_img_height - self.radius:
            done = True
            reward = -10

        # 指定のstep数経過したら試行を終了する
        if self.step_count > self.max_episode_len:
            done = True
        else:
            self.step_count += 1

        return state, reward, done, {}

    # 環境を初期化して状態を返す
    def reset(self):
        # Tracerの中心位置を初期化
        self.pos_x = 400
        self.pos_y = 80

        # Tracerの向き (0～2πで表現)を初期化
        self.direction = 0

        # センサーの位置を取得
        self.pos_sensor_list = self.get_sensor_pos()

        # step数のカウントを初期化
        self.step_count = 0

        # OpenCV2のウィンドウを破棄する
        cv2.destroyAllWindows()

        return np.array([1.0 if np.sum(self.target_img[int(x), int(y)]) == 0 else 0.0 for (y, x) in self.pos_sensor_list])

    # 学習の様子をOpenCV2を使用して可視化する
    def render(self):
        target_for_render = np.copy(cv2.cvtColor(self.target_img, cv2.COLOR_GRAY2RGB))
        # Tracerを表示
        cv2.circle(target_for_render, center=(int(self.pos_x), int(self.pos_y)),
                   radius=self.radius, color=(160, 0, 0), thickness=2)
        # Tracerのセンサーを表示
        for pos_sensor in self.pos_sensor_list:
            cv2.circle(target_for_render, center=(int(pos_sensor[0]), int(pos_sensor[1])),
                     radius=int(self.radius / (self.obs_size + 1)), color=(255, 0, 255), thickness=1)
        cv2.imshow('render', target_for_render)
        key_code = cv2.waitKey(5)
        return key_code

    # Tracerのセンサー位置を取得する
    def get_sensor_pos(self):
        # Tracerの進行方向側、直径の半分の位置に横一列にセンサーを等間隔に配置する
        sensor_center_x = self.pos_x + np.cos(self.direction) * self.radius / 2
        sensor_center_y = self.pos_y + np.sin(self.direction) * self.radius / 2

        sensor_mark_list = np.linspace(-1/2, 1/2, self.obs_size)
        return [(sensor_center_x - np.sin(self.direction) * self.radius * num,
                    sensor_center_y + np.cos(self.direction) * self.radius * num) for num in sensor_mark_list]
