
liuruoze / Thought-SC2

mini_network_add_map_bn.py 9.91 KB
liuruoze committed on 2020-09-17 09:31: * first commit
import numpy as np
import tensorflow as tf
import param as P
from algo.ppo_add_map_bn import Policy_net, PPOTrain
# for mini game
_SIZE_MINI_INPUT = 20
_SIZE_MINI_ACTIONS = 10
class MiniNetwork(object):

    def __init__(self, sess=None, summary_writer=tf.summary.FileWriter("logs/"), rl_training=False,
                 reuse=False, cluster=None, index=0, device='/gpu:0',
                 ppo_load_path=None, ppo_save_path=None,
                 ob_space_add=0, act_space_add=0,
                 freeze_head=False, use_bn=True,
                 use_sep_net=True, load_latest=False,
                 restore_model=False, restore_from=None, restore_to=None,
                 add_image=False, partial_restore=True, weighted_sum_type='AddWeight',
                 initial_type="original"):
        self.policy_model_path_load = ppo_load_path + "mini"
        self.latest_model_path_load = ppo_load_path + "latest"
        self.policy_model_path_save = ppo_save_path + "mini"
        self.latest_model_path_save = ppo_save_path + "latest"
        self.rl_training = rl_training

        self.reuse = reuse
        self.sess = sess
        self.cluster = cluster
        self.index = index
        self.device = device

        self.ob_space_add = ob_space_add
        self.act_space_add = act_space_add
        self.add_image = add_image

        self.freeze_head = freeze_head
        self.use_bn = use_bn
        self.use_sep_net = use_sep_net
        self.restore_model = restore_model
        self.restore_from = restore_from
        self.restore_to = restore_to
        self.load_latest = load_latest
        self.partial_restore = partial_restore
        self.weighted_sum_type = weighted_sum_type
        self.initial_type = initial_type

        # any added observation, added action or map image input switches on
        # the "add" network variant, which trains with its own learning rate
        if self.ob_space_add == 0 and self.act_space_add == 0 and self.add_image == False:
            self.use_add = False
            self.lr = P.mini_lr
            self.epoch_num = P.mini_epoch_num
        else:
            self.use_add = True
            self.lr = P.mini_lr_add
            self.epoch_num = P.mini_epoch_num

        self._create_graph()

        self.rl_saver = tf.train.Saver()
        self.summary_writer = summary_writer
    def initialize(self):
        init_op = tf.global_variables_initializer()
        self.sess.run(init_op)

    def reset_old_network(self):
        self.policy_ppo.assign_policy_parameters()
        self.policy_ppo.reset_mean_returns()

        self.sess.run(self.results_sum.assign(0))
        self.sess.run(self.game_num.assign(0))
    def _create_graph(self):
        if self.reuse:
            tf.get_variable_scope().reuse_variables()
            assert tf.get_variable_scope().reuse

        worker_device = "/job:worker/task:%d" % self.index + self.device
        with tf.device(tf.train.replica_device_setter(worker_device=worker_device, cluster=self.cluster)):
            self.results_sum = tf.get_variable(name="results_sum", shape=[], initializer=tf.zeros_initializer)
            self.game_num = tf.get_variable(name="game_num", shape=[], initializer=tf.zeros_initializer)

            self.global_step = tf.get_variable(name="global_step", shape=[], dtype=tf.int32,
                                               initializer=tf.zeros_initializer, trainable=False)
            self.global_steps = tf.get_variable(name="iter_steps", shape=[], dtype=tf.int32,
                                                initializer=tf.zeros_initializer, trainable=False)

            self.win_rate = self.results_sum / self.game_num

            self.mean_win_rate = tf.summary.scalar('mean_win_rate_dis', self.results_sum / self.game_num)
            self.merged = tf.summary.merge([self.mean_win_rate])

            mini_scope = "MiniPolicyNN"
            with tf.variable_scope(mini_scope):
                ob_space = _SIZE_MINI_INPUT
                act_space_array = _SIZE_MINI_ACTIONS
                self.policy = Policy_net('policy', self.sess, ob_space, self.ob_space_add,
                                         act_space_array, self.act_space_add, self.freeze_head, self.use_bn, self.use_sep_net,
                                         self.add_image, self.weighted_sum_type, self.initial_type)
                self.policy_old = Policy_net('old_policy', self.sess, ob_space, self.ob_space_add,
                                             act_space_array, self.act_space_add, self.freeze_head, self.use_bn, self.use_sep_net,
                                             self.add_image, self.weighted_sum_type, self.initial_type)
                self.policy_ppo = PPOTrain('PPO', self.sess, self.policy, self.policy_old, lr=self.lr, epoch_num=self.epoch_num)

            var_train_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
            var_all_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)

            # old_policy_saver restores a (possibly filtered) variable set from an
            # earlier checkpoint; new_policy_saver always saves the full graph
            if self.restore_model:
                print('restore_model')
                if self.restore_from == 'mini' and self.restore_to == 'mini':
                    print('restore_model: mini to mini')
                    self.old_policy_saver = tf.train.Saver(var_list=var_all_list)
                elif self.restore_from == 'mini' and self.restore_to == 'source':
                    print('restore_model: mini to source')
                    if self.use_add:
                        variables_to_restore = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.policy.scope)
                        old_variables_to_restore = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.policy_old.scope)
                        variables_to_restore += old_variables_to_restore

                        # remove layers for added obs
                        variables_to_restore = [v for v in variables_to_restore if len(v.name.split('/')) > 2 and 'DenseLayer3' not in v.name.split('/')]

                        # remove layer weight for added action
                        variables_to_restore = [v for v in variables_to_restore if len(v.name.split('/')) > 2 and v.name.split('/')[-2] != 'add_output_layer']
                        if not self.partial_restore:
                            variables_to_restore = [v for v in variables_to_restore if len(v.name.split('/')) > 2 and v.name.split('/')[-2] != 'output']

                        # remove layer weight for weighted sum
                        variables_to_restore = [v for v in variables_to_restore if len(v.name.split('/')) > 2 and
                                                'AdaptiveWeight:0' not in v.name.split('/') and 'AttentionWeight' not in v.name.split('/')]

                        if self.add_image:
                            map_variables_to_restore = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.policy.map_variable_scope)
                            old_map_variables_to_restore = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.policy_old.map_variable_scope)

                            # remove layers for added map
                            variables_to_restore = [v for v in variables_to_restore if v not in map_variables_to_restore and v not in old_map_variables_to_restore]

                        print('restore_model: mini to source, use_add')
                        #print('variables_to_restore:', variables_to_restore)
                        self.old_policy_saver = tf.train.Saver(var_list=variables_to_restore)
                    else:
                        self.old_policy_saver = tf.train.Saver(var_list=var_all_list)
                elif self.restore_from == 'source' and self.restore_to == 'source':
                    self.old_policy_saver = tf.train.Saver(var_list=var_all_list)
                else:
                    self.old_policy_saver = tf.train.Saver(var_list=var_all_list)
            else:
                self.old_policy_saver = tf.train.Saver(var_list=var_all_list)

            self.new_policy_saver = tf.train.Saver(var_list=var_all_list)
    def Update_result(self, result_list):
        win = 0
        for i in result_list:
            if i > 0:
                win += 1
        self.sess.run(self.results_sum.assign_add(win))
        self.sess.run(self.game_num.assign_add(len(result_list)))

    def Update_summary(self, counter):
        print("Update summary........")
        policy_summary = self.policy_ppo.get_summary_dis()
        self.summary_writer.add_summary(policy_summary, counter)

        summary = self.sess.run(self.merged)
        self.summary_writer.add_summary(summary, counter)
        print("Update summary finished!")

        self.sess.run(self.global_steps.assign(counter))
        steps = int(self.sess.run(self.global_steps))
        win_game = int(self.sess.run(self.results_sum))
        all_game = int(self.sess.run(self.game_num))
        #print('all_game:', all_game)
        win_rate = win_game / float(all_game) if all_game != 0 else 0.

        return steps, win_rate

    def get_win_rate(self):
        return float(self.sess.run(self.win_rate))

    def Update_policy(self, buffer):
        self.policy_ppo.ppo_train_dis(buffer.observations, buffer.obs_add, buffer.obs_map, buffer.tech_actions,
                                      buffer.rewards, buffer.values, buffer.values_next, buffer.gaes, buffer.returns,
                                      buffer.return_values, self.index, self.summary_writer)
    def get_global_steps(self):
        return int(self.sess.run(self.global_steps))

    def save_policy(self):
        self.new_policy_saver.save(self.sess, self.policy_model_path_save)
        print("policy has been saved in", self.policy_model_path_save)

    def save_latest_policy(self):
        self.new_policy_saver.save(self.sess, self.latest_model_path_save)
        print("latest policy has been saved in", self.latest_model_path_save)

    def restore_policy(self):
        if self.load_latest:
            self.old_policy_saver.restore(self.sess, self.latest_model_path_load)
            print("Restore policy from", self.latest_model_path_load)
        else:
            self.old_policy_saver.restore(self.sess, self.policy_model_path_load)
            print("Restore policy from", self.policy_model_path_load)