MAML-RL Pytorch 代码解读 (10) -- maml
阿里云国内75折 回扣 微信号:monov8 |
阿里云国际,腾讯云国际,低至75折。AWS 93折 免费开户实名账号 代冲值 优惠多多 微信号:monov8 飞机:@monov6 |
MAML-RL Pytorch 代码解读 (10) – maml_rl/envs/subproc_vec_env.py
文章目录
基本介绍
在网上看到的元学习 MAML 的代码大多是跟图像相关的强化学习这边的代码比较少。
因为自己的思路跟 MAML-RL 相关所以打算读一些源码。
MAML 的原始代码是基于 tensorflow 的在 Github 上找到了基于 Pytorch 源码包学习这个包。
源码链接
https://github.com/dragen1860/MAML-Pytorch-RL
文件路径
./maml_rl/envs/subproc_vec_env.py
import
包
import numpy as np
import multiprocessing as mp
import gym
import sys
import queue
EnvWorker()
类
class EnvWorker(mp.Process):
def __init__(self, remote, env_fn, queue_, lock):
#### 类初始化继承了mp.Process()类应该是一个做多线程的类。结合名字应该是将线程、函数信息、队列和锁传递给了“环境工人”那么将发出器传出的指令接受到“环境工人”中。
"""
:param remote: send/recv connection, type of Pipe
:param env_fn: construct environment function
:param queue_: global queue instance
:param lock: Every worker has a lock
"""
super(EnvWorker, self).__init__()
self.remote = remote # Pipe()
self.env = env_fn() # return a function
self.queue = queue_
self.lock = lock
self.task_id = None
self.done = False
#### 这个函数是用来进行一步啥也不干的时间步。观测信息是全0向量都没有奖励为0done标记是True。
def empty_step(self):
"""
conduct a dummy step
:return:
"""
observation = np.zeros(self.env.observation_space.shape, dtype=np.float32)
reward, done = 0.0, True
return observation, reward, done, {}
#### 这个函数应该是重置环境的意思。self.lock应该是进程的锁当锁被打开时依次读取任务的id号码当队列里面没有id号码的时候执行except异常处理self.done设置为True。当队列里面还有id号码的时候self.done设置为False。
def try_reset(self):
"""
:return:
"""
with self.lock:
try:
self.task_id = self.queue.get(True) # block = True
self.done = (self.task_id is None)
except queue.Empty:
self.done = True
#### 如果self.done设置为True说明队列没有任务了观测信息就清空否则就重置环境。
# construct empty state or get state from env.reset()
observation = np.zeros(self.env.observation_space.shape, dtype=np.float32) if self.done else self.env.reset()
return observation
#### 对线程做一些处理。
def run(self):
"""
:return:
"""
while True:
#### 应该是从管道中接受到数据和指令类似于串口输入
command, data = self.remote.recv()
#### 如果指令是'step'说明就要在环境中采取一次时间步获得下一时刻的观测信息、奖励和是否完成信号。如果self.done设置为True说明队列没有任务了就执行空步骤反之就是将接受到的data数据应该是动作信息传入self.env.step()函数中获得下一时刻的观测信息、奖励和是否完成信号。
if command == 'step':
observation, reward, done, info = (self.empty_step() if self.done else self.env.step(data))
#### 如果一个episode执行完了self.done设置为False,这说明还需要继续执行因此就self.try_reset()获得观测信息。
if done and (not self.done):
observation = self.try_reset()
#### 将获得的下一时刻的观测信息、奖励和是否完成信号输出到管道中。
self.remote.send((observation, reward, done, self.task_id, info))
#### 如果指令是'reset'说明只是纯粹的重置环境那么就执行self.try_reset()将状态信息发送给管道。
elif command == 'reset':
observation = self.try_reset()
self.remote.send((observation, self.task_id))
#### 如果指令是'reset_task'也就是任务重置那么就执行self.env.unwrapped.reset_task(data)并输出重置成功的标志True。
elif command == 'reset_task':
self.env.unwrapped.reset_task(data)
self.remote.send(True)
#### 如果指令是'close'说明进程要结束了执行self.remote.close()。
elif command == 'close':
self.remote.close()
break
#### 如果指令是'get_spaces'也就是获得空间的观测信息将当前环境的观测信息发出去。如果是其他指令就报异常。
elif command == 'get_spaces':
self.remote.send((self.env.observation_space, self.env.action_space))
else:
raise NotImplementedError()
SubprocVecEnv()
类
class SubprocVecEnv(gym.Env):
#### 这个类应该是创建子进程环境。self.lock应该是设置进程锁保证多进程中只有一个进程是读写数据的。self.remotes和self.work_remotes数据的收发。用EnvWorker()类为每个子进程构建一个小智能体。
def __init__(self, env_factorys, queue_):
"""
:param env_factorys: list of [lambda x: def p: envs.make(env_name), return p], len: num_workers
:param queue:
"""
self.lock = mp.Lock()
# remotes: all recv conn, len: 8, here duplex=True
# works_remotes: all send conn, len: 8, here duplex=True
self.remotes, self.work_remotes = zip(*[mp.Pipe() for _ in env_factorys])
# queue and lock is shared.
self.workers = [EnvWorker(remote, env_fn, queue_, self.lock)
for (remote, env_fn) in zip(self.work_remotes, env_factorys)]
#### 在for循环里面有依次使能一个智能体这样。
# start 8 processes to interact with environments.
for worker in self.workers:
worker.daemon = True
worker.start()
for remote in self.work_remotes:
remote.close()
self.waiting = False # for step_async
self.closed = False
#### 看作者的注释说既然父进程需要跟子进程联系那么需要用一个方式传递这些数据。在这里使用mp.Pipe()类来收发数据。将收到的数据解耦赋值给self.observation_space和self.action_space。
# Since the main process need talk to children processes, we need a way to comunicate between these.
# here we use mp.Pipe() to send/recv data.
self.remotes[0].send(('get_spaces', None))
observation_space, action_space = self.remotes[0].recv()
self.observation_space = observation_space
self.action_space = action_space
#### 等待每个子进程环境下的运行结果输出的结果就是self.step_wait()的结果也就是一个时间步下面的状态、奖励、是否完成的信息。
def step(self, actions):
"""
step synchronously
:param actions:
:return:
"""
self.step_async(actions)
# wait until step state overdue
return self.step_wait()
#### 将每个进程和实时动作信息打包发送给各个子进程然后打上self.waiting = True标签。
def step_async(self, actions):
"""
step asynchronouly
:param actions:
:return:
"""
# let each sub-process step
for remote, action in zip(self.remotes, actions):
remote.send(('step', action))
self.waiting = True
#### 收集每个远程子进程的数据保存到results中self.waiting设置成False将results的内容分解出来得到下一个时间步的观测、奖励信号、是否完成、任务号和其他信息。最后将这些观测信息又拼接起来。
def step_wait(self):
results = [remote.recv() for remote in self.remotes]
self.waiting = False
observations, rewards, dones, task_ids, infos = zip(*results)
return np.stack(observations), np.stack(rewards), np.stack(dones), task_ids, infos
#### 同步地重置环境。将重置环境的结果保存在results中。解耦合results可以得到每个任务的重置观测和任务序列号task_ids。最后整合所有的初始观测和任务序列号。
def reset(self):
"""
reset synchronously
:return:
"""
for remote in self.remotes:
remote.send(('reset', None))
results = [remote.recv() for remote in self.remotes]
observations, task_ids = zip(*results)
return np.stack(observations), task_ids
#### 重置整个任务输出的是重置任务后的所有数据。
def reset_task(self, tasks):
for remote, task in zip(self.remotes, tasks):
remote.send(('reset_task', task))
return np.stack([remote.recv() for remote in self.remotes])
#### 关闭一系列子进程。如果已经是关闭的了不用执行直接返回。如果是self.waiting==True先接受每个子进程的数据然后对每个子进程输出‘close’结束的标志最后关闭。
def close(self):
if self.closed:
return
if self.waiting: # cope with step_async()
for remote in self.remotes:
remote.recv()
for remote in self.remotes:
remote.send(('close', None))
for worker in self.workers:
worker.join()
self.closed = True
总结
这个类应该是创建子进程的过程可能这样高效一些。
具体来说是异步发送指令信号同步接受并执行最后父进程依次收取子进程的信息并打包起来。
这里有些库涉及到了 multiprocessing
这个库所以还需要调其他文档再理解一下。