Adding Boltzmann Model and WolfSheep Model to Mesa_RL (projectmesa#197)
* Seeding RL Folder
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Formatting Corrections
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Re-formatting
* Reformatting
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Minor corrections
* Minor corrections
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Adding 2 more examples
* Formatting Code
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Improvements
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent: 9897884
Commit: a106204
Showing 14 changed files with 886 additions and 0 deletions.
@@ -0,0 +1,15 @@
# Balancing Wealth Inequality

This folder showcases how to solve the Boltzmann wealth model with Proximal Policy Optimization (PPO) from Stable Baselines3.

## Key features:

- Boltzmann Wealth Model: Agents with varying wealth navigate a grid, aiming to minimize inequality as measured by the Gini coefficient (see the sketch after this list).
- PPO Training: A PPO agent is trained to achieve this goal, receiving sparse rewards based on Gini coefficient improvement and a large terminal reward for achieving low inequality.
- Mesa Data Collection and Visualization: The Mesa data collector tracks Gini values during training, allowing for real-time visualization.
- Visualization Script: Visualize the trained agent's behavior with Mesa's visualization tools, which present agent movement and Gini values within the grid. You can run the `server.py` file to test it with a pre-trained model.
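The Gini coefficient referenced above is the quantity the reward is built around. As a point of reference, the standard discrete formula over agent wealth looks roughly like the sketch below; the example itself imports `compute_gini` from `mesa_models`, so this function name and body are illustrative rather than the packaged implementation.

```python
def compute_gini_sketch(model):
    # Sort wealth ascending to build the Lorenz curve.
    wealths = sorted(agent.wealth for agent in model.schedule.agents)
    n = len(wealths)
    total = sum(wealths)
    if n == 0 or total == 0:
        return 0.0
    # Normalized area under the Lorenz curve.
    b = sum(x * (n - i) for i, x in enumerate(wealths)) / (n * total)
    # Gini = 1 + 1/n - 2B; 0 means perfect equality, values near 1 mean high inequality.
    return 1 + (1 / n) - 2 * b
```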

## Model Behaviour
Because Stable Baselines3 controls all the agents with the same policy weights, the agents learn to move towards a corner of the grid. This brings all the agents together, allowing money to be exchanged between them and maximizing the reward.
<p align="center">
  <img src="ppo_agent.gif" width="500" height="400">
</p>
@@ -0,0 +1,161 @@
""" | ||
This code implements a multi-agent model called MoneyModel using the Mesa library. | ||
The model simulates the distribution of wealth among agents in a grid environment. | ||
Each agent has a randomly assigned wealth and can move to neighboring cells. | ||
Agents can also give money to other agents in the same cell if they have greater wealth. | ||
The model is trained by a scientist who believes in an equal society and wants to minimize the Gini coefficient, which measures wealth inequality. | ||
The model is trained using the Proximal Policy Optimization (PPO) algorithm from the stable-baselines3 library. | ||
The trained model is saved as "ppo_money_model". | ||
""" | ||
|
||
import random | ||
|
||
import gymnasium | ||
import matplotlib.pyplot as plt | ||
|
||
# Import mesa | ||
import mesa | ||
|
||
# Import necessary libraries | ||
import numpy as np | ||
import seaborn as sns | ||
from mesa_models.boltzmann_wealth_model.model import ( | ||
BoltzmannWealthModel, | ||
MoneyAgent, | ||
compute_gini, | ||
) | ||
|
||
NUM_AGENTS = 10 | ||
|
||
|
||
# Define the agent class | ||
class MoneyAgentRL(MoneyAgent): | ||
def __init__(self, unique_id, model): | ||
super().__init__(unique_id, model) | ||
self.wealth = np.random.randint(1, NUM_AGENTS) | ||
|
||
def move(self, action): | ||
empty_neighbors = self.model.grid.get_neighborhood( | ||
self.pos, moore=True, include_center=False | ||
) | ||
|
||
# Define the movement deltas | ||
moves = { | ||
0: (1, 0), # Move right | ||
1: (-1, 0), # Move left | ||
2: (0, -1), # Move up | ||
3: (0, 1), # Move down | ||
4: (0, 0), # Stay in place | ||
} | ||
|
||
# Get the delta for the action, defaulting to (0, 0) if the action is invalid | ||
dx, dy = moves.get(int(action), (0, 0)) | ||
|
||
# Calculate the new position and wrap around the grid | ||
new_position = ( | ||
(self.pos[0] + dx) % self.model.grid.width, | ||
(self.pos[1] + dy) % self.model.grid.height, | ||
) | ||
|
||
# Move the agent if the new position is in empty_neighbors | ||
if new_position in empty_neighbors: | ||
self.model.grid.move_agent(self, new_position) | ||
|
||
def take_money(self): | ||
# Get all agents in the same cell | ||
cellmates = self.model.grid.get_cell_list_contents([self.pos]) | ||
if len(cellmates) > 1: | ||
# Choose a random agent from the cellmates | ||
other_agent = random.choice(cellmates) | ||
if other_agent.wealth > self.wealth: | ||
# Transfer money from other_agent to self | ||
other_agent.wealth -= 1 | ||
self.wealth += 1 | ||
|
||
def step(self): | ||
# Get the action for the agent | ||
action = self.model.action_dict[self.unique_id] | ||
# Move the agent based on the action | ||
self.move(action) | ||
# Take money from other agents in the same cell | ||
self.take_money() | ||
|
||
|
||
# Define the model class | ||
class BoltzmannWealthModelRL(BoltzmannWealthModel, gymnasium.Env): | ||
def __init__(self, N, width, height): | ||
super().__init__(N, width, height) | ||
# Define the observation and action space for the RL model | ||
# The observation space is the wealth of each agent and their position | ||
self.observation_space = gymnasium.spaces.Box(low=0, high=10 * N, shape=(N, 3)) | ||
# The action space is a MultiDiscrete space with 5 possible actions for each agent | ||
self.action_space = gymnasium.spaces.MultiDiscrete([5] * N) | ||
self.is_visualize = False | ||
|
||
def step(self, action): | ||
self.action_dict = action | ||
# Perform one step of the model | ||
self.schedule.step() | ||
# Collect data for visualization | ||
self.datacollector.collect(self) | ||
# Compute the new Gini coefficient | ||
new_gini = compute_gini(self) | ||
# Compute the reward based on the change in Gini coefficient | ||
reward = self.calculate_reward(new_gini) | ||
self.prev_gini = new_gini | ||
# Get the observation for the RL model | ||
obs = self._get_obs() | ||
if self.schedule.time > 5 * NUM_AGENTS: | ||
# Terminate the episode if the model has run for a certain number of timesteps | ||
done = True | ||
reward = -1 | ||
elif new_gini < 0.1: | ||
# Terminate the episode if the Gini coefficient is below a certain threshold | ||
done = True | ||
reward = 50 / self.schedule.time | ||
else: | ||
done = False | ||
info = {} | ||
truncated = False | ||
return obs, reward, done, truncated, info | ||
|
||
def calculate_reward(self, new_gini): | ||
if new_gini < self.prev_gini: | ||
# Compute the reward based on the decrease in Gini coefficient | ||
reward = (self.prev_gini - new_gini) * 20 | ||
else: | ||
# Penalize for increase in Gini coefficient | ||
reward = -0.05 | ||
self.prev_gini = new_gini | ||
return reward | ||
|
||
def visualize(self): | ||
# Visualize the Gini coefficient over time | ||
gini = self.datacollector.get_model_vars_dataframe() | ||
g = sns.lineplot(data=gini) | ||
g.set(title="Gini Coefficient over Time", ylabel="Gini Coefficient") | ||
plt.show() | ||
|
||
def reset(self, *, seed=None, options=None): | ||
if self.is_visualize: | ||
# Visualize the Gini coefficient before resetting the model | ||
self.visualize() | ||
super().reset() | ||
self.grid = mesa.space.MultiGrid(self.grid.width, self.grid.height, True) | ||
self.schedule = mesa.time.RandomActivation(self) | ||
for i in range(self.num_agents): | ||
# Create MoneyAgentRL instances and add them to the schedule | ||
a = MoneyAgentRL(i, self) | ||
self.schedule.add(a) | ||
x = self.random.randrange(self.grid.width) | ||
y = self.random.randrange(self.grid.height) | ||
self.grid.place_agent(a, (x, y)) | ||
self.prev_gini = compute_gini(self) | ||
return self._get_obs(), {} | ||
|
||
def _get_obs(self): | ||
# The observation is the wealth of each agent and their position | ||
obs = [] | ||
for a in self.schedule.agents: | ||
obs.append([a.wealth, *list(a.pos)]) | ||
return np.array(obs) |
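To make the environment's control flow concrete, here is a minimal usage sketch (not part of this commit) that drives it with random actions; it assumes the file above is importable as `model`, as the imports in the server script below suggest.

```python
from model import NUM_AGENTS, BoltzmannWealthModelRL

# Illustrative smoke test: run one episode with random actions.
env = BoltzmannWealthModelRL(N=NUM_AGENTS, width=NUM_AGENTS, height=NUM_AGENTS)
obs, info = env.reset()
done = truncated = False
while not (done or truncated):
    action = env.action_space.sample()  # one random action per agent (MultiDiscrete)
    obs, reward, done, truncated, info = env.step(action)
```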
@@ -0,0 +1,69 @@
import os

import mesa
from mesa.visualization.ModularVisualization import ModularServer
from mesa.visualization.modules import ChartModule
from model import BoltzmannWealthModelRL
from stable_baselines3 import PPO


# Modify the MoneyModel class to take actions from the RL model
class MoneyModelRL(BoltzmannWealthModelRL):
    def __init__(self, N, width, height):
        super().__init__(N, width, height)
        model_path = os.path.join(
            os.path.dirname(__file__), "..", "model", "boltzmann_money.zip"
        )
        self.rl_model = PPO.load(model_path)
        self.reset()

    def step(self):
        # Collect data
        self.datacollector.collect(self)

        # Get the observations, which are the wealth of each agent and their position
        obs = self._get_obs()

        action, _states = self.rl_model.predict(obs)
        self.action_dict = action
        self.schedule.step()


# Define the agent portrayal with different colors for different wealth levels
def agent_portrayal(agent):
    if agent.wealth > 10:
        color = "purple"
    elif agent.wealth > 7:
        color = "red"
    elif agent.wealth > 5:
        color = "orange"
    elif agent.wealth > 3:
        color = "yellow"
    else:
        color = "blue"

    portrayal = {
        "Shape": "circle",
        "Filled": "true",
        "Layer": 0,
        "Color": color,
        "r": 0.5,
    }
    return portrayal


if __name__ == "__main__":
    # Define a grid visualization
    grid = mesa.visualization.CanvasGrid(agent_portrayal, 10, 10, 500, 500)

    # Define a chart visualization
    chart = ChartModule(
        [{"Label": "Gini", "Color": "Black"}], data_collector_name="datacollector"
    )

    # Create a modular server
    server = ModularServer(
        MoneyModelRL, [grid, chart], "Money Model", {"N": 10, "width": 10, "height": 10}
    )
    server.port = 8521  # The default
    server.launch()
@@ -0,0 +1,35 @@
import argparse

from model import NUM_AGENTS, BoltzmannWealthModelRL
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import EvalCallback


def rl_model(args):
    # Create the environment
    env = BoltzmannWealthModelRL(N=NUM_AGENTS, width=NUM_AGENTS, height=NUM_AGENTS)
    eval_env = BoltzmannWealthModelRL(N=NUM_AGENTS, width=NUM_AGENTS, height=NUM_AGENTS)
    eval_callback = EvalCallback(
        eval_env, best_model_save_path="./logs/", log_path="./logs/", eval_freq=5000
    )
    # Define the PPO model
    model = PPO("MlpPolicy", env, verbose=1, tensorboard_log="./logs/")

    # Train the model
    model.learn(total_timesteps=args.stop_timesteps, callback=[eval_callback])

    # Save the model
    model.save("ppo_money_model")


if __name__ == "__main__":
    # Define the command line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--stop-timesteps",
        type=int,
        default=NUM_AGENTS * 100,
        help="Number of timesteps to train.",
    )
    args = parser.parse_args()
    rl_model(args)
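After training, the policy saved as "ppo_money_model" can be reloaded and used to drive a full episode. A minimal sketch (illustrative, not part of this commit), assuming the environment module is importable as `model`:

```python
from model import NUM_AGENTS, BoltzmannWealthModelRL
from stable_baselines3 import PPO

env = BoltzmannWealthModelRL(N=NUM_AGENTS, width=NUM_AGENTS, height=NUM_AGENTS)
model = PPO.load("ppo_money_model")

obs, info = env.reset()
env.is_visualize = True  # plot the Gini time series at the next reset
done = truncated = False
while not (done or truncated):
    action, _ = model.predict(obs)
    obs, reward, done, truncated, info = env.step(action)
env.reset()  # triggers the Gini plot for the episode just collected
```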
@@ -0,0 +1,33 @@
# Collaborative Survival: Wolf-Sheep Predation Model

This project demonstrates the use of the RLlib library to implement Multi-Agent Reinforcement Learning (MARL) in the classic Wolf-Sheep predation problem. The environment details can be found on the Mesa project's GitHub repository [here](https://github.com/projectmesa/mesa-examples/tree/main/examples/wolf_sheep).

## Key Features

**RLlib and Multi-Agent Learning**:
- **Library Utilized**: The project leverages the RLlib library to concurrently train two independent PPO (Proximal Policy Optimization) agents.
- **Agents**:
  - **Wolf**: Predatory agent that survives by eating sheep
  - **Sheep**: Prey agent that survives by eating grass
  - **Grass**: Grass is eaten by sheep and regrows with time

**Input and Observation Space**:
- **Observation Grid**: Each agent's policy receives a 10x10 grid centered on itself as input.
- **Grid Details**: The grid incorporates information about the presence of other agents (wolves, sheep, and grass) within it.
- **Agent's Energy Level**: The agent's current energy level is also included in the observations.

**Action Space**:
- **Action Space**: The action space is the ID of the neighboring tile to which the agent wants to move. A sketch of how these spaces might be declared follows this list.
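As referenced above, here is a rough sketch of how the observation and action spaces described here could be expressed with gymnasium. The names, channel layout, and neighborhood size are assumptions for illustration, not this project's actual definitions.

```python
import gymnasium as gym
import numpy as np

# 10x10 window centered on the agent, one channel each for wolves, sheep, and grass,
# plus the agent's own energy level.
observation_space = gym.spaces.Dict(
    {
        "grid": gym.spaces.Box(low=0, high=1, shape=(10, 10, 3), dtype=np.float32),
        "energy": gym.spaces.Box(low=0.0, high=np.inf, shape=(1,), dtype=np.float32),
    }
)

# The action is the index of the tile to move to: 8 Moore neighbors plus staying put.
action_space = gym.spaces.Discrete(9)
```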

**Behavior and Training Outcomes**:
- **Optimal Behavior**:
  - **Wolf**: Learns to move towards the nearest sheep.
  - **Sheep**: Learns to run away from wolves and is attracted to grass.
- **Density Variations**: You can vary the densities of sheep and wolves to observe different results.

By leveraging RLlib and Multi-Agent Learning, this project provides insights into the dynamics of predator-prey relationships and optimal behavior strategies in a simulated environment.
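For orientation, training two independent PPO policies in RLlib typically looks roughly like the following. The environment id, policy names, and mapping rule are placeholders, and the exact API varies across RLlib versions, so treat this as a sketch rather than the project's training script.

```python
from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment("wolf_sheep_env")  # assumed: a registered multi-agent env id
    .multi_agent(
        # One policy per species; each agent id is routed to its own policy.
        policies={"wolf_policy", "sheep_policy"},
        policy_mapping_fn=lambda agent_id, *args, **kwargs: (
            "wolf_policy" if "wolf" in str(agent_id) else "sheep_policy"
        ),
    )
)

algo = config.build()
for _ in range(10):
    results = algo.train()  # one training iteration; inspect `results` for metrics
```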

<p align="center">
  <img src="resources/wolf_sheep.gif" width="500" height="400">
</p>