# when-to-switch
# 145 lines · 5.2 KB
1import numpy as np
2import cppimport.import_hook
3from planning.planner import planner
4from pogema import GridConfig
5
6from planning.astar_no_grid import AStar, INF
7
8
class RePlanBase:
    """Re-planning policy that keeps one path planner per agent.

    On every :meth:`act` call each agent's planner is fed the obstacle cells
    (and, unless disabled, the other-agent cells) found in that agent's
    observation, a path to the agent's target is re-planned, and the first
    step of that path is translated into an action index.
    """

    def __init__(self, use_best_move: bool = True, max_steps: int = INF, algo_source='c++', seed=None,
                 ignore_other_agents=False):
        """Store the configuration; planners are created lazily in :meth:`act`.

        Args:
            use_best_move: Forwarded to the planner's ``get_next_node``.
            max_steps: Passed to the planner constructor.
            algo_source: ``'c++'`` selects ``planner`` (imported from
                ``planning.planner``); any other value selects ``AStar``.
            seed: Seed for the ``numpy`` random generator stored in ``self.rnd``.
            ignore_other_agents: If true, ``None`` is passed to the planner
                instead of the cells occupied by other agents.
        """
        self.use_best_move = use_best_move
        gc: GridConfig = GridConfig()

        # Map a (dx, dy) displacement to its index in GridConfig().MOVES.
        self.actions = {tuple(move): index for index, move in enumerate(gc.MOVES)}
        # Incremented once per act() call; not read anywhere in this class.
        self.steps = 0

        self.algo_source = planner if algo_source == 'c++' else AStar
        self.planner = None
        self.max_steps = max_steps
        self.previous_positions = None
        self.rnd = np.random.default_rng(seed)
        self.ignore_other_agents = ignore_other_agents

    def act(self, obs, skip_agents=None):
        """Compute one action per agent from the current observations.

        Args:
            obs: Sequence of per-agent observation dicts. The keys read here
                are ``'obstacles'``, ``'agents'``, ``'xy'`` and ``'target_xy'``.
            skip_agents: Optional sequence of flags; a truthy entry makes the
                corresponding agent receive ``None`` (its planner's obstacles
                are still updated).

        Returns:
            A list with one entry per agent: an index taken from
            ``self.actions``, or ``None`` when the agent is on its target, is
            skipped, or the planner returned no usable next node. An empty
            ``obs`` yields an empty list and leaves the policy state untouched.
        """
        num_agents = len(obs)
        if num_agents == 0:
            # Nothing to plan for; also avoids initialising empty planner
            # lists and indexing obs[0] below.
            return []

        if self.planner is None:
            self.planner = [self.algo_source(self.max_steps) for _ in range(num_agents)]
        if self.previous_positions is None:
            self.previous_positions = [[] for _ in range(num_agents)]

        # Half the side length of the observation window; used to compute the
        # offset (xy - radius) handed to the planner together with the cells.
        obs_radius = len(obs[0]['obstacles']) // 2
        action = []

        for k, agent_obs in enumerate(obs):
            position = agent_obs['xy']
            target = agent_obs['target_xy']

            self.previous_positions[k].append(position)
            if position == target:
                action.append(None)
                continue

            obstacles = np.transpose(np.nonzero(agent_obs['obstacles']))
            if self.ignore_other_agents:
                other_agents = None
            else:
                other_agents = np.transpose(np.nonzero(agent_obs['agents']))
            self.planner[k].update_obstacles(obstacles, other_agents,
                                             (position[0] - obs_radius, position[1] - obs_radius))

            # Skipped agents still get their obstacles refreshed above, but no
            # path is planned for them.
            if skip_agents and skip_agents[k]:
                action.append(None)
                continue

            self.planner[k].update_path(position, target)
            path = self.planner[k].get_next_node(self.use_best_move)
            if path is not None and path[1][0] < INF:
                # path[0] -> path[1] is the displacement to translate into an action.
                delta = (path[1][0] - path[0][0], path[1][1] - path[0][1])
                action.append(self.actions[delta])
            else:
                action.append(None)

        self.steps += 1
        return action

    def get_path(self):
        """Return each planner's ``get_path(use_best_node=False)`` result.

        Returns:
            A list with one entry per planner, or an empty list when no
            planner has been created yet (i.e. :meth:`act` was never called).
        """
        if self.planner is None:
            return []
        return [agent_planner.get_path(use_best_node=False) for agent_planner in self.planner]
69
70
class FixNonesWrapper:
    """Agent wrapper that substitutes action ``0`` for every ``None`` action."""

    def __init__(self, agent):
        """Wrap ``agent``, which must expose ``act(obs, skip_agents=...)`` and ``rnd``."""
        self.agent = agent
        # Expose the wrapped agent's random generator under the same attribute name.
        self.rnd = self.agent.rnd

    def act(self, obs, skip_agents=None):
        """Delegate to the wrapped agent and fill in missing actions.

        The list returned by the wrapped agent is modified in place and
        returned; entries that are ``None`` become ``0``.
        """
        actions = self.agent.act(obs, skip_agents=skip_agents)
        for slot, chosen in enumerate(actions):
            if chosen is None:
                actions[slot] = 0
        return actions
84
85
class NoPathSoRandomOrStayWrapper:
    """Agent wrapper that replaces ``None`` actions with ``0`` or a random move.

    For every ``None`` produced by the wrapped agent a number is drawn from
    the shared random generator: when it is ``<= 0.5`` the action becomes
    ``0``, otherwise :meth:`get_random_move` picks the replacement.
    """

    def __init__(self, agent):
        """Wrap ``agent``, which must expose ``act(obs, skip_agents=...)`` and ``rnd``."""
        self.agent = agent
        self.rnd = self.agent.rnd

    def act(self, obs, skip_agents=None):
        """Delegate to the wrapped agent, then fill in every ``None`` action.

        The list returned by the wrapped agent is modified in place and returned.
        """
        actions = self.agent.act(obs, skip_agents=skip_agents)
        for agent_id, chosen in enumerate(actions):
            if chosen is not None:
                continue
            # One draw per missing action decides between action 0 and a random move.
            use_zero = self.rnd.random() <= 0.5
            actions[agent_id] = 0 if use_zero else self.get_random_move(obs, agent_id)
        return actions

    def get_random_move(self, obs, agent_id):
        """Pick a random action from 1..4 whose destination cell is obstacle-free.

        The candidate actions are shuffled with the wrapped agent's generator
        and tried in that order; the destination is looked up in the agent's
        ``'obstacles'`` grid relative to the grid's centre cell. Returns ``0``
        when all four destinations are blocked.
        """
        deltas = GridConfig().MOVES
        candidates = [1, 2, 3, 4]
        self.agent.rnd.shuffle(candidates)

        grid = obs[agent_id]['obstacles']
        centre = len(grid) // 2
        for move_id in candidates:
            row = centre + deltas[move_id][0]
            col = centre + deltas[move_id][1]
            if grid[row][col] == 0:
                return move_id
        return 0
113
114
class FixLoopsWrapper(NoPathSoRandomOrStayWrapper):
    """Agent wrapper that breaks short back-and-forth loops.

    The wrapper remembers the positions each agent was observed at and, when
    the wrapped agent's action would move an agent onto one of its two most
    recently recorded positions, overrides that action.
    """

    def __init__(self, agent, stay_if_loop_prob=None, add_none_if_loop=False):
        """Configure the loop-breaking behaviour.

        Args:
            agent: Wrapped agent exposing ``act(obs, skip_agents=...)`` and ``rnd``.
            stay_if_loop_prob: Probability of replacing a looping move with
                action ``0``. ``None`` selects the default of ``0.5``; an
                explicit ``0`` is honoured and disables that replacement.
            add_none_if_loop: If true, a looping action is replaced by ``None``
                instead of being resolved here.
        """
        # The base initialiser stores ``agent`` and ``rnd``.
        super().__init__(agent)
        self.previous_positions = None
        # Compare against None so that a caller-supplied 0 / 0.0 is not
        # silently replaced by the default.
        self.stay_if_loop_prob = 0.5 if stay_if_loop_prob is None else stay_if_loop_prob
        self.add_none_if_loop = add_none_if_loop

    def act(self, obs, skip_agents=None):
        """Delegate to the wrapped agent and override actions that form a loop.

        For each agent with a non-``None`` action and at least two recorded
        positions, the position the action leads to is computed from
        ``GridConfig().MOVES``. If it equals one of the last two recorded
        positions then, in order of precedence:

        * ``add_none_if_loop`` set -> the action becomes ``None``;
        * the action does not change the position -> a random move is chosen
          via ``get_random_move``;
        * otherwise, with probability ``stay_if_loop_prob`` -> action ``0``.

        The agent's current position is recorded only on steps where the
        wrapped agent produced a non-``None`` action for it.

        Returns:
            The (in-place modified) action list from the wrapped agent.
        """
        if self.previous_positions is None:
            self.previous_positions = [[] for _ in range(len(obs))]

        actions = self.agent.act(obs, skip_agents=skip_agents)
        # Loop-invariant: fetch the move table once per call.
        moves = GridConfig().MOVES

        for idx, proposed in enumerate(actions):
            if proposed is None:
                continue

            history = self.previous_positions[idx]
            current = obs[idx]['xy']

            if len(history) > 1:
                dx, dy = moves[proposed]
                # NOTE(review): ``next_pos`` is a tuple, so the comparisons
                # below assume ``obs[idx]['xy']`` is a tuple as well - confirm.
                next_pos = dx + current[0], dy + current[1]
                if history[-1] == next_pos or history[-2] == next_pos:
                    if self.add_none_if_loop:
                        actions[idx] = None
                    elif next_pos == current:
                        actions[idx] = self.get_random_move(obs, idx)
                    elif self.rnd.random() < self.stay_if_loop_prob:
                        actions[idx] = 0

            history.append(current)
        return actions
146