From 2580654b88717bbd433dc5d839e068733db76aad Mon Sep 17 00:00:00 2001
From: Marcel Gern
Date: Mon, 11 Mar 2024 11:32:26 +0100
Subject: [PATCH] added ResoBee example

---
 resobee/resobee_example.ipynb | 191 ++++++++++++++++++++++++++++++++++
 1 file changed, 191 insertions(+)
 create mode 100644 resobee/resobee_example.ipynb

diff --git a/resobee/resobee_example.ipynb b/resobee/resobee_example.ipynb
new file mode 100644
index 0000000..441a742
--- /dev/null
+++ b/resobee/resobee_example.ipynb
@@ -0,0 +1,191 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "\n",
+    "import flax.linen as nn\n",
+    "import matplotlib.pyplot as plt\n",
+    "import numpy as np\n",
+    "import optax\n",
+    "\n",
+    "import swarmrl as srl\n",
+    "import swarmrl.engine.resobee as resobee\n",
+    "from swarmrl.actions.actions import Action"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## RL Configuration"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "In this code block the task and the training parameters are defined; this is where the goal of the RL procedure is set."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class ActorCriticNet(nn.Module):\n",
+    "    \"\"\"A simple dense model with separate actor and critic branches.\"\"\"\n",
+    "\n",
+    "    @nn.compact\n",
+    "    def __call__(self, x):\n",
+    "        # Two parallel dense branches acting on the same observable.\n",
+    "        y = nn.Dense(features=12)(x)\n",
+    "        x = nn.Dense(features=12)(x)\n",
+    "        x = nn.relu(x)\n",
+    "        y = nn.relu(y)\n",
+    "\n",
+    "        y = nn.Dense(features=12)(y)\n",
+    "        x = nn.Dense(features=12)(x)\n",
+    "        x = nn.relu(x)\n",
+    "        y = nn.relu(y)\n",
+    "        y = nn.Dense(features=1)(y)  # critic head\n",
+    "        x = nn.Dense(features=4)(x)  # actor head, one logit per action\n",
+    "        return x, y\n",
+    "\n",
+    "\n",
+    "# Define an exploration policy.\n",
+    "exploration_policy = srl.exploration_policies.RandomExploration(probability=0.1)\n",
+    "\n",
+    "# Define a sampling strategy.\n",
+    "sampling_strategy = srl.sampling_strategies.GumbelDistribution()\n",
+    "\n",
+    "# Value function to use.\n",
+    "value_function = srl.value_functions.ExpectedReturns(gamma=0.1, standardize=True)\n",
+    "\n",
+    "# Define the model.\n",
+    "actor_critic = ActorCriticNet()\n",
+    "network = srl.networks.FlaxModel(\n",
+    "    flax_model=actor_critic,\n",
+    "    optimizer=optax.adam(learning_rate=0.01),\n",
+    "    input_shape=(2,),\n",
+    "    sampling_strategy=sampling_strategy,\n",
+    "    exploration_policy=exploration_policy,\n",
+    ")\n",
+    "\n",
+    "\n",
+    "def scale_function(distance: float):\n",
+    "    \"\"\"Scaling function for the gradient-sensing task.\"\"\"\n",
+    "    return 1 - distance\n",
+    "\n",
+    "\n",
+    "task = srl.tasks.searching.GradientSensing(\n",
+    "    source=np.array([10.0, 10.0]),\n",
+    "    decay_function=scale_function,\n",
+    "    reward_scale_factor=100,\n",
+    "    box_length=np.array([20.0, 20.0]),\n",
+    ")\n",
+    "\n",
+    "observable = srl.observables.PositionObservable(np.array([20.0, 20.0]))\n",
+    "\n",
+    "# Define the loss model.\n",
+    "loss = srl.losses.PolicyGradientLoss(value_function=value_function)\n",
+    "\n",
+    "actions = {\n",
+    "    \"TranslateLeft\": Action(force=10.0, new_direction=np.array([-10.0, 0.0])),\n",
+    "    \"TranslateUp\": Action(force=10.0, new_direction=np.array([0.0, 10.0])),\n",
+    "    \"TranslateRight\": Action(force=10.0, new_direction=np.array([10.0, 0.0])),\n",
+    "    \"TranslateDown\": Action(force=10.0, new_direction=np.array([0.0, -10.0])),\n",
+    "}\n",
+    "\n",
+    "protocol = srl.agents.ActorCriticAgent(\n",
+    "    particle_type=0,\n",
+    "    network=network,\n",
+    "    task=task,\n",
+    "    observable=observable,\n",
+    "    actions=actions,\n",
+    "    loss=loss,\n",
+    ")\n",
+    "\n",
+    "# Define the trainer.\n",
+    "rl_trainer = srl.trainers.EpisodicTrainer([protocol])\n",
+    "n_episodes = 100"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The path to the ResoBee root directory needs to be specified here."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Fill in the path to the local ResoBee checkout.\n",
+    "resobee_root_path = \"\"\n",
+    "\n",
+    "build_path = os.path.join(resobee_root_path, \"build\")\n",
+    "config_dir = os.path.join(\n",
+    "    resobee_root_path, \"workflow/projects/debug/parameter-combination-0/seed-0\"\n",
+    ")\n",
+    "\n",
+    "target = \"many_body_simulation\"\n",
+    "resobee_executable = os.path.join(build_path, \"src\", target)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "system_runner = resobee.ResoBee(\n",
+    "    resobee_executable=resobee_executable,\n",
+    "    config_dir=config_dir,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_engine(system):\n",
+    "    \"\"\"Return the engine for each episode (here, the ResoBee runner itself).\"\"\"\n",
+    "    return system\n",
+    "\n",
+    "\n",
+    "reward = rl_trainer.perform_rl_training(\n",
+    "    get_engine=get_engine,\n",
+    "    system=system_runner,\n",
+    "    n_episodes=n_episodes,\n",
+    "    episode_length=1,\n",
+    ")\n",
+    "\n",
+    "plt.plot(reward)\n",
+    "plt.xlabel(\"episodes\")\n",
+    "plt.ylabel(\"reward\")\n",
+    "plt.show()"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.13"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 1
+}
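
Outside the patch itself, a small standalone sketch (plain NumPy and Matplotlib, no SwarmRL imports) may help readers see the reward shape that the notebook's scale_function hands to the GradientSensing task. Treating the particle-source distance as already normalised to [0, 1] is an assumption made purely for this illustration; the patch does not show how SwarmRL scales the raw distance internally.

    import matplotlib.pyplot as plt
    import numpy as np

    # Assumption for illustration: distance to the source normalised to [0, 1].
    distance = np.linspace(0.0, 1.0, 100)

    # Same functional form as scale_function in the notebook:
    # the decay falls linearly from 1 at the source to 0 at maximum distance.
    decay = 1.0 - distance

    plt.plot(distance, decay)
    plt.xlabel("normalised distance to source")
    plt.ylabel("decay value")
    plt.title("Linear decay used for the gradient-sensing reward")
    plt.show()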