Multi-Agent Simulated Environment: Petting Zoo
In this example, we show how to define multi-agent simulations with simulated environments. Like ours single-agent example with Gymnasium, we create an agent-environment loop with an externally defined environment. The main difference is that we now implement this kind of interaction loop with multiple agents instead. We will use the Petting Zoo library, which is the multi-agent counterpart to Gymnasium.
Install pettingzoo
and other dependencies
!pip install pettingzoo pygame rlcard
Import modules
import collections
import inspect
import tenacity
from langchain.chat_models import ChatOpenAI
from langchain.schema import (
HumanMessage,
SystemMessage,
)
from langchain.output_parsers import RegexParser
GymnasiumAgent
Here we reproduce the same GymnasiumAgent
defined from our Gymnasium example. If after multiple retries it does not take a valid action, it simply takes a random action.
class GymnasiumAgent:
@classmethod
def get_docs(cls, env):
return env.unwrapped.__doc__
def __init__(self, model, env):
self.model = model
self.env = env
self.docs = self.get_docs(env)
self.instructions = """
Your goal is to maximize your return, i.e. the sum of the rewards you receive.
I will give you an observation, reward, terminiation flag, truncation flag, and the return so far, formatted as:
Observation: <observation>
Reward: <reward>
Termination: <termination>
Truncation: <truncation>
Return: <sum_of_rewards>
You will respond with an action, formatted as:
Action: <action>
where you replace <action> with your actual action.
Do nothing else but return the action.
"""
self.action_parser = RegexParser(
regex=r"Action: (.*)", output_keys=["action"], default_output_key="action"
)
self.message_history = []
self.ret = 0
def random_action(self):
action = self.env.action_space.sample()
return action
def reset(self):
self.message_history = [
SystemMessage(content=self.docs),
SystemMessage(content=self.instructions),
]
def observe(self, obs, rew=0, term=False, trunc=False, info=None):
self.ret += rew
obs_message = f"""
Observation: {obs}
Reward: {rew}
Termination: {term}
Truncation: {trunc}
Return: {self.ret}
"""
self.message_history.append(HumanMessage(content=obs_message))
return obs_message
def _act(self):
act_message = self.model(self.message_history)
self.message_history.append(act_message)
action = int(self.action_parser.parse(act_message.content)["action"])
return action
def act(self):
try:
for attempt in tenacity.Retrying(
stop=tenacity.stop_after_attempt(2),
wait=tenacity.wait_none(), # No waiting time between retries
retry=tenacity.retry_if_exception_type(ValueError),
before_sleep=lambda retry_state: print(
f"ValueError occurred: {retry_state.outcome.exception()}, retrying..."
),
):
with attempt:
action = self._act()
except tenacity.RetryError as e:
action = self.random_action()
return action
Main loop
def main(agents, env):
env.reset()
for name, agent in agents.items():
agent.reset()
for agent_name in env.agent_iter():
observation, reward, termination, truncation, info = env.last()
obs_message = agents[agent_name].observe(
observation, reward, termination, truncation, info
)
print(obs_message)
if termination or truncation:
action = None
else:
action = agents[agent_name].act()
print(f"Action: {action}")
env.step(action)
env.close()
PettingZooAgent
The PettingZooAgent
extends the GymnasiumAgent
to the multi-agent setting. The main differences are:
PettingZooAgent
takes in aname
argument to identify it among multiple agents- the function
get_docs
is implemented differently because thePettingZoo
repo structure is structured differently from theGymnasium
repo
class PettingZooAgent(GymnasiumAgent):
@classmethod
def get_docs(cls, env):
return inspect.getmodule(env.unwrapped).__doc__
def __init__(self, name, model, env):
super().__init__(model, env)
self.name = name
def random_action(self):
action = self.env.action_space(self.name).sample()
return action
Rock, Paper, Scissors
We can now run a simulation of a multi-agent rock, paper, scissors game using the PettingZooAgent
.
from pettingzoo.classic import rps_v2
env = rps_v2.env(max_cycles=3, render_mode="human")
agents = {
name: PettingZooAgent(name=name, model=ChatOpenAI(temperature=1), env=env)
for name in env.possible_agents
}
main(agents, env)
Observation: 3
Reward: 0
Termination: False
Truncation: False
Return: 0
Action: 1
Observation: 3
Reward: 0
Termination: False
Truncation: False
Return: 0
Action: 1
Observation: 1
Reward: 0
Termination: False
Truncation: False
Return: 0
Action: 2
Observation: 1
Reward: 0
Termination: False
Truncation: False
Return: 0
Action: 1
Observation: 1
Reward: 1
Termination: False
Truncation: False
Return: 1
Action: 0
Observation: 2
Reward: -1
Termination: False
Truncation: False
Return: -1
Action: 0
Observation: 0
Reward: 0
Termination: False
Truncation: True
Return: 1
Action: None
Observation: 0
Reward: 0
Termination: False
Truncation: True
Return: -1
Action: None
ActionMaskAgent
Some PettingZoo
environments provide an action_mask
to tell the agent which actions are valid. The ActionMaskAgent
subclasses PettingZooAgent
to use information from the action_mask
to select actions.
class ActionMaskAgent(PettingZooAgent):
def __init__(self, name, model, env):
super().__init__(name, model, env)
self.obs_buffer = collections.deque(maxlen=1)
def random_action(self):
obs = self.obs_buffer[-1]
action = self.env.action_space(self.name).sample(obs["action_mask"])
return action
def reset(self):
self.message_history = [
SystemMessage(content=self.docs),
SystemMessage(content=self.instructions),
]
def observe(self, obs, rew=0, term=False, trunc=False, info=None):
self.obs_buffer.append(obs)
return super().observe(obs, rew, term, trunc, info)
def _act(self):
valid_action_instruction = "Generate a valid action given by the indices of the `action_mask` that are not 0, according to the action formatting rules."
self.message_history.append(HumanMessage(content=valid_action_instruction))
return super()._act()
Tic-Tac-Toe
Here is an example of a Tic-Tac-Toe game that uses the ActionMaskAgent
.
from pettingzoo.classic import tictactoe_v3
env = tictactoe_v3.env(render_mode="human")
agents = {
name: ActionMaskAgent(name=name, model=ChatOpenAI(temperature=0.2), env=env)
for name in env.possible_agents
}
main(agents, env)
Observation: {'observation': array([[[0, 0],
[0, 0],
[0, 0]],
[[0, 0],
[0, 0],
[0, 0]],
[[0, 0],
[0, 0],
[0, 0]]], dtype=int8), 'action_mask': array([1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int8)}
Reward: 0
Termination: False
Truncation: False
Return: 0
Action: 0
| |
X | - | -
_____|_____|_____
| |
- | - | -
_____|_____|_____
| |
- | - | -
| |
Observation: {'observation': array([[[0, 1],
[0, 0],
[0, 0]],
[[0, 0],
[0, 0],
[0, 0]],
[[0, 0],
[0, 0],
[0, 0]]], dtype=int8), 'action_mask': array([0, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int8)}
Reward: 0
Termination: False
Truncation: False
Return: 0
Action: 1
| |
X | - | -
_____|_____|_____
| |
O | - | -
_____|_____|_____
| |
- | - | -
| |
Observation: {'observation': array([[[1, 0],
[0, 1],
[0, 0]],
[[0, 0],
[0, 0],
[0, 0]],
[[0, 0],
[0, 0],
[0, 0]]], dtype=int8), 'action_mask': array([0, 0, 1, 1, 1, 1, 1, 1, 1], dtype=int8)}
Reward: 0
Termination: False
Truncation: False
Return: 0
Action: 2
| |
X | - | -
_____|_____|_____
| |
O | - | -
_____|_____|_____
| |
X | - | -
| |
Observation: {'observation': array([[[0, 1],
[1, 0],
[0, 1]],
[[0, 0],
[0, 0],
[0, 0]],
[[0, 0],
[0, 0],
[0, 0]]], dtype=int8), 'action_mask': array([0, 0, 0, 1, 1, 1, 1, 1, 1], dtype=int8)}
Reward: 0
Termination: False
Truncation: False
Return: 0
Action: 3
| |
X | O | -
_____|_____|_____
| |
O | - | -
_____|_____|_____
| |
X | - | -
| |
Observation: {'observation': array([[[1, 0],
[0, 1],
[1, 0]],
[[0, 1],
[0, 0],
[0, 0]],
[[0, 0],
[0, 0],
[0, 0]]], dtype=int8), 'action_mask': array([0, 0, 0, 0, 1, 1, 1, 1, 1], dtype=int8)}
Reward: 0
Termination: False
Truncation: False
Return: 0
Action: 4
| |
X | O | -
_____|_____|_____
| |
O | X | -
_____|_____|_____
| |
X | - | -
| |
Observation: {'observation': array([[[0, 1],
[1, 0],
[0, 1]],
[[1, 0],
[0, 1],
[0, 0]],
[[0, 0],
[0, 0],
[0, 0]]], dtype=int8), 'action_mask': array([0, 0, 0, 0, 0, 1, 1, 1, 1], dtype=int8)}
Reward: 0
Termination: False
Truncation: False
Return: 0
Action: 5
| |
X | O | -
_____|_____|_____
| |
O | X | -
_____|_____|_____
| |
X | O | -
| |
Observation: {'observation': array([[[1, 0],
[0, 1],
[1, 0]],
[[0, 1],
[1, 0],
[0, 1]],
[[0, 0],
[0, 0],
[0, 0]]], dtype=int8), 'action_mask': array([0, 0, 0, 0, 0, 0, 1, 1, 1], dtype=int8)}
Reward: 0
Termination: False
Truncation: False
Return: 0
Action: 6
| |
X | O | X
_____|_____|_____
| |
O | X | -
_____|_____|_____
| |
X | O | -
| |
Observation: {'observation': array([[[0, 1],
[1, 0],
[0, 1]],
[[1, 0],
[0, 1],
[1, 0]],
[[0, 1],
[0, 0],
[0, 0]]], dtype=int8), 'action_mask': array([0, 0, 0, 0, 0, 0, 0, 1, 1], dtype=int8)}
Reward: -1
Termination: True
Truncation: False
Return: -1
Action: None
Observation: {'observation': array([[[1, 0],
[0, 1],
[1, 0]],
[[0, 1],
[1, 0],
[0, 1]],
[[1, 0],
[0, 0],
[0, 0]]], dtype=int8), 'action_mask': array([0, 0, 0, 0, 0, 0, 0, 1, 1], dtype=int8)}
Reward: 1
Termination: True
Truncation: False
Return: 1
Action: None
Texas Hold'em No Limit
Here is an example of a Texas Hold'em No Limit game that uses the ActionMaskAgent
.
from pettingzoo.classic import texas_holdem_no_limit_v6
env = texas_holdem_no_limit_v6.env(num_players=4, render_mode="human")
agents = {
name: ActionMaskAgent(name=name, model=ChatOpenAI(temperature=0.2), env=env)
for name in env.possible_agents
}
main(agents, env)
Observation: {'observation': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0.,
0., 0., 2.], dtype=float32), 'action_mask': array([1, 1, 0, 1, 1], dtype=int8)}
Reward: 0
Termination: False
Truncation: False
Return: 0
Action: 1
Observation: {'observation': array([0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
0., 0., 2.], dtype=float32), 'action_mask': array([1, 1, 0, 1, 1], dtype=int8)}
Reward: 0
Termination: False
Truncation: False
Return: 0
Action: 1
Observation: {'observation': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 1., 2.], dtype=float32), 'action_mask': array([1, 1, 1, 1, 1], dtype=int8)}
Reward: 0
Termination: False
Truncation: False
Return: 0
Action: 1
Observation: {'observation': array([0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 2., 2.], dtype=float32), 'action_mask': array([1, 1, 1, 1, 1], dtype=int8)}
Reward: 0
Termination: False
Truncation: False
Return: 0
Action: 0
Observation: {'observation': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1.,
0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 2., 2.], dtype=float32), 'action_mask': array([1, 1, 1, 1, 1], dtype=int8)}
Reward: 0
Termination: False
Truncation: False
Return: 0
Action: 2
Observation: {'observation': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 1., 0., 0., 1., 1., 0., 0., 1., 0., 0., 0., 0.,
0., 2., 6.], dtype=float32), 'action_mask': array([1, 1, 1, 1, 1], dtype=int8)}
Reward: 0
Termination: False
Truncation: False
Return: 0
Action: 2
Observation: {'observation': array([0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0.,
0., 2., 8.], dtype=float32), 'action_mask': array([1, 1, 1, 1, 1], dtype=int8)}
Reward: 0
Termination: False
Truncation: False
Return: 0
Action: 3
Observation: {'observation': array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0.,
1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
6., 20.], dtype=float32), 'action_mask': array([1, 1, 1, 1, 1], dtype=int8)}
Reward: 0
Termination: False
Truncation: False
Return: 0
Action: 4
Observation: {'observation': array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 1.,
0., 0., 1., 0., 0., 0., 0., 0., 8., 100.],
dtype=float32), 'action_mask': array([1, 1, 0, 0, 0], dtype=int8)}
Reward: 0
Termination: False
Truncation: False
Return: 0
Action: 4
[WARNING]: Illegal move made, game terminating with current player losing.
obs['action_mask'] contains a mask of all legal moves that can be chosen.
Observation: {'observation': array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 1.,
0., 0., 1., 0., 0., 0., 0., 0., 8., 100.],
dtype=float32), 'action_mask': array([1, 1, 0, 0, 0], dtype=int8)}
Reward: -1.0
Termination: True
Truncation: True
Return: -1.0
Action: None
Observation: {'observation': array([ 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0.,
0., 0., 0., 0., 1., 0., 0., 0., 20., 100.],
dtype=float32), 'action_mask': array([1, 1, 0, 0, 0], dtype=int8)}
Reward: 0
Termination: True
Truncation: True
Return: 0
Action: None
Observation: {'observation': array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
1., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 100., 100.],
dtype=float32), 'action_mask': array([1, 1, 0, 0, 0], dtype=int8)}
Reward: 0
Termination: True
Truncation: True
Return: 0
Action: None
Observation: {'observation': array([ 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 2., 100.],
dtype=float32), 'action_mask': array([1, 1, 0, 0, 0], dtype=int8)}
Reward: 0
Termination: True
Truncation: True
Return: 0
Action: None