In this tutorial, we build a safety-critical reinforcement learning pipeline that learns entirely from offline data instead of live exploration. We design a custom environment, generate a behavior dataset from a restricted policy, and then train both a behavior cloning baseline and a conservative Q-learning agent using d3rlpy. By structuring the workflow around offline datasets, careful evaluation, and conservative learning objectives, we demonstrate how robust decision-making policies can be trained in settings where free exploration is not an option.
!pip -q install -U "d3rlpy" "gymnasium" "numpy" "torch" "matplotlib" "scikit-learn"
import os
import time
import random
import inspect
import numpy as np
import matplotlib.pyplot as plt
import gymnasium as gym
from gymnasium import spaces
import torch
import d3rlpy
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
def pick_device():
if torch.cuda.is_available():
return "cuda:0"
return "cpu"
DEVICE = pick_device()
print("d3rlpy:", getattr(d3rlpy, "__version__", "unknown"), "| torch:", torch.__version__, "| device:", DEVICE)
def make_config(cls, **kwargs):
sig = inspect.signature(cls.__init__)
allowed = set(sig.parameters.keys())
allowed.discard("self")
    filtered = {k: v for k, v in kwargs.items() if k in allowed}
return cls(**filtered)
We set up the environment by installing dependencies, importing libraries, and fixing random seeds for reproducibility. We detect the available compute device so the workflow runs consistently on both CPU and GPU machines. We also define a utility that safely creates configuration objects across different d3rlpy versions by filtering out unsupported keyword arguments.
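To see how this utility behaves, here is a minimal sketch using a hypothetical DummyConfig dataclass (not part of d3rlpy); make_config silently drops any keyword argument the target class does not accept instead of raising a TypeError.
from dataclasses import dataclass

@dataclass
class DummyConfig:
    # Hypothetical config class used only for illustration.
    learning_rate: float = 1e-3
    batch_size: int = 128

# "future_flag" is not a DummyConfig.__init__ parameter, so make_config filters it out.
cfg = make_config(DummyConfig, learning_rate=3e-4, batch_size=256, future_flag=True)
print(cfg)  # DummyConfig(learning_rate=0.0003, batch_size=256)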
class SafetyCriticalGridWorld(gym.Env):
    metadata = {"render_modes": []}
def __init__(
self,
size=15,
max_steps=80,
hazard_coords=None,
start=(0, 0),
goal=None,
slip_prob=0.05,
seed=0,
):
super().__init__()
self.size = int(size)
self.max_steps = int(max_steps)
self.start = tuple(start)
self.goal = tuple(goal) if goal is not None else (self.size - 1, self.size - 1)
self.slip_prob = float(slip_prob)
if hazard_coords is None:
hz = set()
rng = np.random.default_rng(seed)
for _ in range(max(1, self.size // 2)):
x = rng.integers(2, self.size - 2)
y = rng.integers(2, self.size - 2)
hz.add((int(x), int(y)))
self.hazards = hz
else:
self.hazards = set(tuple(x) for x in hazard_coords)
self.action_space = spaces.Discrete(4)
self.observation_space = spaces.Box(low=0.0, high=float(self.size - 1), shape=(2,), dtype=np.float32)
self._rng = np.random.default_rng(seed)
self._pos = None
self._t = 0
def reset(self, *, seed=None, options=None):
if seed is not None:
self._rng = np.random.default_rng(seed)
self._pos = [int(self.start[0]), int(self.start[1])]
self._t = 0
obs = np.array(self._pos, dtype=np.float32)
        return obs, {}
def _clip(self):
self._pos[0] = int(np.clip(self._pos[0], 0, self.size - 1))
self._pos[1] = int(np.clip(self._pos[1], 0, self.size - 1))
def step(self, action):
self._t += 1
a = int(action)
if self._rng.random() < self.slip_prob:
a = int(self._rng.integers(0, 4))
if a == 0:
self._pos[1] += 1
elif a == 1:
self._pos[0] += 1
elif a == 2:
self._pos[1] -= 1
elif a == 3:
self._pos[0] -= 1
self._clip()
x, y = int(self._pos[0]), int(self._pos[1])
terminated = False
truncated = self._t >= self.max_steps
reward = -1.0
if (x, y) in self.hazards:
reward = -100.0
terminated = True
if (x, y) == self.goal:
reward = +50.0
terminated = True
obs = np.array([x, y], dtype=np.float32)
        return obs, float(reward), terminated, truncated, {}
We define a safety-critical gridworld environment with hazards, terminal conditions, and stochastic slip transitions. We encode penalties for entering hazardous cells and a positive reward for reaching the goal. We keep movement clipped to the grid so the dynamics reflect the hard constraints found in real-world safety-critical systems.
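As a quick sanity check, a minimal sketch like the one below (not part of the main pipeline) rolls out a few random-action episodes and confirms that step penalties, hazard terminations, and the step limit behave as intended.
# Smoke-test the environment with random actions (illustrative only).
env_check = SafetyCriticalGridWorld(size=15, max_steps=80, slip_prob=0.05, seed=SEED)
for ep in range(3):
    obs, _ = env_check.reset(seed=ep)
    total, done = 0.0, False
    while not done:
        obs, r, terminated, truncated, _ = env_check.step(env_check.action_space.sample())
        total += float(r)
        done = bool(terminated or truncated)
    print(f"episode {ep}: return={total:.1f}, final_pos=({int(obs[0])}, {int(obs[1])})")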
def safe_behavior_policy(obs, env: SafetyCriticalGridWorld, epsilon=0.15):
x, y = int(obs[0]), int(obs[1])
gx, gy = env.goal
preferred = []
if gx > x:
preferred.append(1)
elif gx < x:
preferred.append(3)
if gy > y:
preferred.append(0)
elif gy < y:
preferred.append(2)
if len(preferred) == 0:
preferred = [int(env._rng.integers(0, 4))]
if env._rng.random() < epsilon:
return int(env._rng.integers(0, 4))
candidates = []
for a in preferred:
nx, ny = x, y
if a == 0:
ny += 1
elif a == 1:
nx += 1
elif a == 2:
ny -= 1
elif a == 3:
nx -= 1
nx = int(np.clip(nx, 0, env.size - 1))
ny = int(np.clip(ny, 0, env.size - 1))
if (nx, ny) not in env.hazards:
candidates.append(a)
if len(candidates) == 0:
return preferred[0]
return int(random.choice(candidates))
def generate_offline_episodes(env, n_episodes=400, epsilon=0.20, seed=0):
episodes = []
for i in range(n_episodes):
obs, _ = env.reset(seed=int(seed + i))
obs_list = []
act_list = []
rew_list = []
done_list = []
done = False
while not done:
a = safe_behavior_policy(obs, env, epsilon=epsilon)
nxt, r, terminated, truncated, _ = env.step(a)
done = bool(terminated or truncated)
obs_list.append(np.array(obs, dtype=np.float32))
act_list.append(np.array([a], dtype=np.int64))
rew_list.append(np.array([r], dtype=np.float32))
done_list.append(np.array([1.0 if done else 0.0], dtype=np.float32))
obs = nxt
        episodes.append(
            {
                "observations": np.stack(obs_list, axis=0),
                "actions": np.stack(act_list, axis=0),
                "rewards": np.stack(rew_list, axis=0),
                "terminals": np.stack(done_list, axis=0),
            }
        )
return episodes
def build_mdpdataset(episodes):
obs = np.concatenate([ep["observations"] for ep in episodes], axis=0).astype(np.float32)
acts = np.concatenate([ep["actions"] for ep in episodes], axis=0).astype(np.int64)
rews = np.concatenate([ep["rewards"] for ep in episodes], axis=0).astype(np.float32)
terms = np.concatenate([ep["terminals"] for ep in episodes], axis=0).astype(np.float32)
if hasattr(d3rlpy, "dataset") and hasattr(d3rlpy.dataset, "MDPDataset"):
return d3rlpy.dataset.MDPDataset(observations=obs, actions=acts, rewards=rews, terminals=terms)
raise RuntimeError("d3rlpy.dataset.MDPDataset not found. Upgrade d3rlpy.")
We design a hazard-avoiding behavior policy that generates offline data without risky exploration. We roll this policy out to collect trajectories and group the transitions into episodes. We then convert these episodes into d3rlpy's MDPDataset format so they are compatible with the offline learning API.
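As a small illustration of the data format, the sketch below (assuming the helpers defined above) collects a handful of episodes and inspects the stacked arrays before building the full dataset.
# Collect a few behavior episodes and inspect their array shapes (illustrative only).
tiny_env = SafetyCriticalGridWorld(size=15, max_steps=80, slip_prob=0.05, seed=SEED)
tiny_eps = generate_offline_episodes(tiny_env, n_episodes=5, epsilon=0.2, seed=123)
for i, ep in enumerate(tiny_eps):
    print(i, ep["observations"].shape, ep["actions"].shape, ep["rewards"].shape)
tiny_dataset = build_mdpdataset(tiny_eps)
print("built:", type(tiny_dataset).__name__)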
def _get_episodes_from_dataset(dataset):
if hasattr(dataset, "episodes") and dataset.episodes is not None:
return dataset.episodes
if hasattr(dataset, "get_episodes"):
return dataset.get_episodes()
raise AttributeError("Could not find episodes in dataset (d3rlpy version mismatch).")
def _iter_all_observations(dataset):
for ep in _get_episodes_from_dataset(dataset):
obs = getattr(ep, "observations", None)
if obs is None:
continue
for o in obs:
yield o
def _iter_all_transitions(dataset):
for ep in _get_episodes_from_dataset(dataset):
obs = getattr(ep, "observations", None)
acts = getattr(ep, "actions", None)
rews = getattr(ep, "rewards", None)
if obs is None or acts is None:
continue
n = min(len(obs), len(acts))
for i in range(n):
o = obs[i]
a = acts[i]
r = rews[i] if rews is not None and i < len(rews) else None
yield o, a, r
def visualize_dataset(dataset, env, title="Offline Dataset"):
state_visits = np.zeros((env.size, env.size), dtype=np.float32)
for obs in _iter_all_observations(dataset):
x, y = int(obs[0]), int(obs[1])
x = int(np.clip(x, 0, env.size - 1))
y = int(np.clip(y, 0, env.size - 1))
state_visits[y, x] += 1
plt.figure(figsize=(6, 5))
plt.imshow(state_visits, origin="lower")
plt.colorbar(label="Visits")
plt.scatter([env.start[0]], [env.start[1]], marker="o", label="start")
plt.scatter([env.goal[0]], [env.goal[1]], marker="*", label="goal")
if len(env.hazards) > 0:
hz = np.array(list(env.hazards), dtype=np.int32)
plt.scatter(hz[:, 0], hz[:, 1], marker="x", label="hazards")
plt.title(f"title — State visitation")
plt.xlabel("x")
plt.ylabel("y")
plt.legend()
plt.show()
rewards = []
for _, _, r in _iter_all_transitions(dataset):
if r is not None:
rewards.append(float(r))
if len(rewards) > 0:
plt.figure(figsize=(6, 4))
plt.hist(rewards, bins=60)
plt.title(f"title — Reward distribution")
plt.xlabel("reward")
plt.ylabel("count")
plt.show()
We implement dataset utilities that correctly iterate through episodes instead of assuming flat arrays. We visualize state visitation to understand coverage and data bias in the offline dataset. We also plot the reward distribution to see the learning signal available to the agent.
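The same iterators also support simple numeric diagnostics. The sketch below adds a hypothetical coverage_fraction helper (not part of the original code) that reports how much of the grid the behavior data actually visits.
# Hypothetical helper: fraction of grid cells visited at least once in the dataset.
def coverage_fraction(dataset, env):
    visited = set()
    for obs in _iter_all_observations(dataset):
        visited.add((int(obs[0]), int(obs[1])))
    return len(visited) / float(env.size * env.size)

# Example usage once `dataset` and `env` have been built:
# print("grid coverage:", coverage_fraction(dataset, env))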
def rollout_eval(env, algo, n_episodes=25, seed=0):
returns = []
lengths = []
hazard_hits = 0
goal_hits = 0
for i in range(n_episodes):
obs, _ = env.reset(seed=seed + i)
done = False
total = 0.0
steps = 0
while not done:
a = int(algo.predict(np.asarray(obs, dtype=np.float32)[None, ...])[0])
obs, r, terminated, truncated, _ = env.step(a)
total += float(r)
steps += 1
done = bool(terminated or truncated)
if terminated:
x, y = int(obs[0]), int(obs[1])
if (x, y) in env.hazards:
hazard_hits += 1
if (x, y) == env.goal:
goal_hits += 1
returns.append(total)
lengths.append(steps)
    return {
        "return_mean": float(np.mean(returns)),
        "return_std": float(np.std(returns)),
        "len_mean": float(np.mean(lengths)),
        "hazard_rate": float(hazard_hits / max(1, n_episodes)),
        "goal_rate": float(goal_hits / max(1, n_episodes)),
        "returns": np.asarray(returns, dtype=np.float32),
    }
def action_mismatch_rate_vs_data(dataset, algo, sample_obs=7000, seed=0):
rng = np.random.default_rng(seed)
obs_all = []
act_all = []
for o, a, _ in _iter_all_transitions(dataset):
obs_all.append(np.asarray(o, dtype=np.float32))
act_all.append(int(np.asarray(a).reshape(-1)[0]))
if len(obs_all) >= 80_000:
break
obs_all = np.stack(obs_all, axis=0)
act_all = np.asarray(act_all, dtype=np.int64)
idx = rng.choice(len(obs_all), size=min(sample_obs, len(obs_all)), replace=False)
obs_probe = obs_all[idx]
act_probe_data = act_all[idx]
act_probe_pi = algo.predict(obs_probe).astype(np.int64)
mismatch = (act_probe_pi != act_probe_data).astype(np.float32)
return float(mismatch.mean())
def create_discrete_bc(device):
if hasattr(d3rlpy.algos, "DiscreteBCConfig"):
cls = d3rlpy.algos.DiscreteBCConfig
cfg = make_config(
cls,
learning_rate=3e-4,
batch_size=256,
)
return cfg.create(device=device)
if hasattr(d3rlpy.algos, "DiscreteBC"):
return d3rlpy.algos.DiscreteBC()
raise RuntimeError("DiscreteBC not available in this d3rlpy version.")
def create_discrete_cql(device, conservative_weight=6.0):
if hasattr(d3rlpy.algos, "DiscreteCQLConfig"):
cls = d3rlpy.algos.DiscreteCQLConfig
cfg = make_config(
cls,
learning_rate=3e-4,
actor_learning_rate=3e-4,
critic_learning_rate=3e-4,
temp_learning_rate=3e-4,
alpha_learning_rate=3e-4,
batch_size=256,
conservative_weight=float(conservative_weight),
n_action_samples=10,
rollout_interval=0,
)
return cfg.create(device=device)
if hasattr(d3rlpy.algos, "DiscreteCQL"):
algo = d3rlpy.algos.DiscreteCQL()
if hasattr(algo, "conservative_weight"):
try:
algo.conservative_weight = float(conservative_weight)
except Exception:
pass
return algo
raise RuntimeError("DiscreteCQL not available in this d3rlpy version.")
We define controlled evaluation routines that measure policy performance with only a small budget of online rollouts. We compute return statistics along with safety metrics such as hazard and goal rates. We also introduce a mismatch diagnostic that measures how often the learned policy's actions deviate from the dataset's actions at the same states.
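To make the mismatch diagnostic concrete, here is a tiny worked example with toy arrays: the rate is simply the fraction of probed states where the policy's greedy action differs from the action logged in the dataset.
# Toy illustration of the mismatch-rate computation (not part of the pipeline).
logged_actions = np.array([0, 1, 1, 2, 3], dtype=np.int64)   # actions stored in the dataset
policy_actions = np.array([0, 1, 2, 2, 0], dtype=np.int64)   # greedy actions from a policy
mismatch_rate = float((policy_actions != logged_actions).mean())
print(mismatch_rate)  # 0.4 -> the policy disagrees with the data on 2 of 5 states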
def main():
env = SafetyCriticalGridWorld(
size=15,
max_steps=80,
slip_prob=0.05,
seed=SEED,
)
raw_eps = generate_offline_episodes(env, n_episodes=500, epsilon=0.22, seed=SEED)
dataset = build_mdpdataset(raw_eps)
print("dataset built:", type(dataset).__name__)
visualize_dataset(dataset, env, title="Behavior Dataset (Offline)")
bc = create_discrete_bc(DEVICE)
cql = create_discrete_cql(DEVICE, conservative_weight=6.0)
print("\nTraining Discrete BC (offline)...")
t0 = time.time()
bc.fit(
dataset,
n_steps=25_000,
n_steps_per_epoch=2_500,
experiment_name="grid_bc_offline",
)
print("BC train sec:", round(time.time() - t0, 2))
print("\nTraining Discrete CQL (offline)...")
t0 = time.time()
cql.fit(
dataset,
n_steps=80_000,
n_steps_per_epoch=8_000,
experiment_name="grid_cql_offline",
)
print("CQL train sec:", round(time.time() - t0, 2))
print("\nControlled online evaluation (small number of rollouts):")
bc_metrics = rollout_eval(env, bc, n_episodes=30, seed=SEED + 1000)
cql_metrics = rollout_eval(env, cql, n_episodes=30, seed=SEED + 2000)
print("BC :", k: v for k, v in bc_metrics.items() if k != "returns")
print("CQL:", k: v for k, v in cql_metrics.items() if k != "returns")
print("\nOOD-ish diagnostic (policy action mismatch vs data action at same states):")
bc_mismatch = action_mismatch_rate_vs_data(dataset, bc, sample_obs=7000, seed=SEED + 1)
cql_mismatch = action_mismatch_rate_vs_data(dataset, cql, sample_obs=7000, seed=SEED + 2)
print("BC mismatch rate :", bc_mismatch)
print("CQL mismatch rate:", cql_mismatch)
plt.figure(figsize=(6, 4))
labels = ["BC", "CQL"]
means = [bc_metrics["return_mean"], cql_metrics["return_mean"]]
stds = [bc_metrics["return_std"], cql_metrics["return_std"]]
plt.bar(labels, means, yerr=stds)
plt.ylabel("Return")
plt.title("Online Rollout Return (Controlled)")
plt.show()
plt.figure(figsize=(6, 4))
plt.plot(np.sort(bc_metrics["returns"]), label="BC")
plt.plot(np.sort(cql_metrics["returns"]), label="CQL")
plt.xlabel("Episode (sorted)")
plt.ylabel("Return")
plt.title("Return Distribution (Sorted)")
plt.legend()
plt.show()
out_dir = "/content/offline_rl_artifacts"
os.makedirs(out_dir, exist_ok=True)
bc_path = os.path.join(out_dir, "grid_bc_policy.pt")
cql_path = os.path.join(out_dir, "grid_cql_policy.pt")
if hasattr(bc, "save_policy"):
bc.save_policy(bc_path)
print("Saved BC policy:", bc_path)
if hasattr(cql, "save_policy"):
cql.save_policy(cql_path)
print("Saved CQL policy:", cql_path)
print("\nDone.")
if __name__ == "__main__":
main()
We train both the behavior cloning and conservative Q-learning agents entirely on offline data. We compare their performance using controlled rollouts and the mismatch diagnostic. We finish the workflow by saving the trained policies and summarizing the safety-aware learning results.
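If you later want to use a saved policy outside d3rlpy, the sketch below assumes save_policy exported a TorchScript module (the default in recent d3rlpy versions); the exported format can differ across versions, so treat this as illustrative rather than definitive.
# Load an exported policy and query a greedy action (assumes a TorchScript export).
policy = torch.jit.load("/content/offline_rl_artifacts/grid_cql_policy.pt")
policy.eval()
obs_t = torch.tensor([[0.0, 0.0]], dtype=torch.float32)  # a single (x, y) observation
with torch.no_grad():
    action = policy(obs_t)
print("greedy action:", int(action.reshape(-1)[0]))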
In conclusion, we demonstrated that conservative Q-learning yields a more reliable policy than simple imitation when learning from historical data in safety-critical environments. By comparing offline training outcomes, controlled online evaluations, and action-distribution mismatches, we illustrate how conservatism helps reduce risky, out-of-distribution behavior. Overall, we presented a complete, reproducible offline RL workflow that can be extended to more complex domains such as robotics, healthcare, or finance without compromising safety.