What matters in reinforcement learning for tractography
Reference: Théberge, A., Desrosiers, C., Boré, A., Descoteaux, M., & Jodoin, P. M. (2024). What matters in reinforcement learning for tractography. Medical Image Analysis, 93, 103085.
Reviewed by Jeremi Levesque. https://arxiv.org/pdf/2305.09041.pdf
Key recommendations:
- SAC Auto is the best-performing algorithm
- Use fODFs, no WM mask, and two previous directions for the state
- Seed from the WM/GM interface if bundle coverage isn’t a priority
- Train on the ISMRM2015 dataset (to generalize to in-vivo data)
- Perform an extensive hyperparameter search before training to select the agent
- Vary the step size after training to accommodate differences in voxel sizes
- Maybe add a small termination (bonus) reward.
Abstract #
- DeepRL to learn the tractography procedure & train agents to reconstruct the structure of WM without manually curated streamlines.
- Trained ~7,400 models (~41,000 GPU-hours) to explore what works and what doesn’t and to guide researchers
- Recommendations of RL algorithm choice, inputs to the agents, reward function, & more.
Introduction #
- Enough streamlines -> approx. of the white matter tissue
- Initial seeds in white matter or interface
- Tractography is an ill-posed problem: it tries to reconstruct global connectivity from local information. Global and other methods still require further work to resolve this.
- ML tractography algorithms are trained with a gold standard (as a reference) with raw diffusion data so they can learn the entire tractography procedure and hopefully improve on classical methods.
- ML methods show improvements over classical methods, but are very limited by their data (synthetic phantoms don’t represent reality correctly). Tractograms can be annotated by experts, but this is very time consuming and the annotations may still be imprecise because the reference tractograms come from an imperfect gold standard.
- Track-to-Learn (TTL): the only RL-for-tractography framework available today. Highly competitive results on in-vivo and in-silico datasets.
- Sections of this article:
- Pitfalls to avoid and recommendations for training agents with TTL.
- Performance of RL algorithms on tractography
- Performance/correctness with different input signals
- Performance/correctness with different reward functions
- Open-source framework
Related works #
- Wanyan et al. implemented the first RL method, which builds expanding graphs representing paths between the beginning and ending nodes of a bundle. TD learning is used to evaluate the nodes of the graphs, and agents are rewarded when they successfully connect the beginning and ending regions of a bundle. It is unclear how the graphs are propagated or how spatial positions are defined as nodes, and the method assumes known starting and ending regions, which is not generally the case.
- TTL is the first functional deep-RL tractography method. It views tractography as if a robot were dropped inside a maze, the path followed by the robot being the produced streamline. The diffusion signal is the state space; the reward, computed at each step, is the alignment of the new streamline segment with the fODF peaks and with the previous segment.
- Two similar studies covered insights and practical recommendations, for on-policy methods and for offline learning from (human) expert data.
Method #
Preliminaries #
Tractography #
- Parameters
- \(p_0\): seed, i.e. the initial position
- \(\Delta\): step size
- \(v(p_t)\): direction given by the local model at position \(p_t\)
For the first half of the streamline (forward propagation), until max_length or a stopping criterion is reached: $$p_{t+1} = p_t + \Delta v(p_t)$$
For the other half-streamline (backwards propagation): $$p_{t+1} = p_t - \Delta v(p_t)$$
Tracking is initiated (seeded) in the WM or at the WM/GM interface. See [[(Review) Particle Filtering Tractography]] for an improvement on this tractography procedure.
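A minimal sketch of this stepping loop, with `get_direction` and `stopping_criterion` as placeholder callables (not names from the paper):

```python
import numpy as np

def track_half_streamline(seed, get_direction, stopping_criterion,
                          step_size=0.75, max_steps=200, forward=True):
    """Propagate one half-streamline from a seed by following the local model."""
    sign = 1.0 if forward else -1.0
    streamline = [np.asarray(seed, dtype=float)]
    for _ in range(max_steps):
        p = streamline[-1]
        v = get_direction(p)                 # direction from the local model (e.g. fODF peak)
        if v is None or stopping_criterion(p):
            break
        streamline.append(p + sign * step_size * v)   # p_{t+1} = p_t +/- Delta * v(p_t)
    return np.array(streamline)

# The full streamline is the backward half (forward=False) reversed and
# concatenated with the forward half.
```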
Reinforcement learning #
RL agent goal: maximise the RL objective (the expected sum of rewards over an episode) $$J(\pi) = E_{s_t, a_t \sim \pi}\left[\sum_{t=0}^{T} R(s_t, a_t)\right]$$
Learn the optimal policy \(\pi^*\), i.e. the one that maximizes J: $$\pi^* = \arg\max_{\pi} J(\pi)$$
Value function (expected discounted return from a state; used to evaluate the most desirable states): $$V^{\pi}(s_t) = E_{a_{t'} \sim \pi}\left[\sum_{t'=t}^{T} \gamma^{t'-t} r_{t'}\right]$$
Q-function (expected return by executing an action and following the policy for subsequent states): $$Q^{\pi}(s_t, a_t) = r_t + \gamma V^{\pi}(s_{t + 1})$$
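As a small numerical companion to these definitions (illustrative only, not code from the paper), the discounted return and the one-step relation between Q and V can be written as:

```python
def discounted_return(rewards, gamma):
    """G = r_0 + gamma * r_1 + gamma^2 * r_2 + ..., computed right-to-left."""
    g = 0.0
    for r in reversed(rewards):
        g = r + gamma * g
    return g

def q_from_value(r_t, v_next, gamma):
    """One-step relation: Q(s_t, a_t) = r_t + gamma * V(s_{t+1})."""
    return r_t + gamma * v_next
```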
Many algos proposed to solve the optimal policy: model-free and model-based algos.
- Model-free: Directly learn the policy to maximise the RL objective. (See [[Lecture 4, RL Intro.]])
- On-policy: learn as they go; each new transition is used for one update and then discarded, because the policy that generated it has just changed.
- Off-policy: store transitions in a replay buffer and train the policy on samples (typically drawn uniformly) from that buffer (a minimal buffer sketch follows below).
- Model-based: Learn a model of the env and plan the optimal course of action using the model’s estimations.
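A minimal replay-buffer sketch to make the off-policy idea concrete (illustrative only; the TTL implementation may differ):

```python
import random
from collections import deque

class ReplayBuffer:
    """Stores (s, a, r, s_next, done) tuples; off-policy agents sample minibatches from it."""

    def __init__(self, capacity=100_000):
        self.buffer = deque(maxlen=capacity)  # old transitions are evicted automatically

    def push(self, transition):
        self.buffer.append(transition)

    def sample(self, batch_size):
        return random.sample(self.buffer, batch_size)  # uniform sampling

# On-policy agents instead use each fresh transition once, then discard it
# after updating the policy.
```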
Track-to-Learn #
- Inputs needed: signal volume (raw dMRI or the fODFs), tracking mask, peaks volume.
- State-space: signal volume. State is the signal at the tip of the tracked streamline.
- State also includes the signal from the six immediate neighbours (up, down, left, right, front, back), the WM mask value at the tip of the streamline and its neighbors.
- Finally, state also includes the last four streamline segments (for directional information)
- The agent action = orientation of the tracking step. This orientation is then scaled to the tracking step size to produce the segment \(u_t\). That new segment is used to obtain \(p_{t+1}\) and thus \(s_{t+1}\) (after appending additional information). $$u_t = \Delta \frac{a_t}{||a_t||}$$
- Reward (\(v(p_t)\) are the fODF peaks at \(p_t\)): $$r_t = |\max_{v(p_t)} \langle v(p_t), u_t \rangle| \cdot \langle u_t, u_{t-1} \rangle$$ A sketch of one full environment step appears at the end of this section.
- Episode: tracking of a streamline
- Stops if:
- Tracking out of tracking mask (i.e. WM)
- Angle too high on new orientation
- Reached max length (streamline too long)
- Cumulative angle between segments is too high (the streamline is looping)
- TTL uses 2 environments:
- Forward tracking
- Backwards tracking: starts from the end of the half-streamline, which is re-tracked until the tracking arrives back at the original seed; from there the agent tracks freely in the other direction. Re-tracking the half-streamline doesn’t modify it, but the actions are still rewarded and used for training.
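A hedged sketch of one environment step (action scaling plus reward), assuming a placeholder `peaks_at(p)` that returns unit fODF peak directions at position p:

```python
import numpy as np

def environment_step(p_t, a_t, d_prev, peaks_at, step_size=0.75):
    """Scale the agent's action into a segment, advance the position, compute the reward."""
    d_t = a_t / np.linalg.norm(a_t)        # unit direction chosen by the agent
    u_t = step_size * d_t                  # u_t = Delta * a_t / ||a_t||
    p_next = p_t + u_t

    # r_t = |max_v <v(p_t), d_t>| * <d_t, d_{t-1}>, with segments taken as unit vectors here.
    peaks = peaks_at(p_t)                  # (n_peaks, 3) array of unit peak directions
    peak_term = float(np.max(np.abs(peaks @ d_t)))
    smooth_term = float(np.dot(d_t, d_prev)) if d_prev is not None else 1.0
    return p_next, d_t, peak_term * smooth_term
```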
Benchmarking & Experiments #
Datasets #
- Fibercup: Synthetic FiberCup dataset which is a recreation of the original physical FiberCup using Fiberfox tool.
- Challenging fiber configurations: kissings, fannings, crossings.
- 64x64x3 of 3mm isometric voxels, b-value of 1000 s/mm² over 30 gradient directions.
- No artifacts, SNR ~40.
- ISMRM2015: Synthetic virtual phantom mimicking human brain.
- 25 manually segmented bundles from subjects of the Human Connectome Project dataset; a diffusion volume and a structural T1 image were simulated with Fiberfox.
- 2mm isometric diffusion volume.
- 32 gradient directions, b-value of 1000 s/mm².
- Synthetic artifacts to mimic clinical data.
- TractoInferno: 284 manually quality-controlled in-vivo acquisitions with T1-weighted images and diffusion data. Data was processed through TractoFlow to denoise the DWI and resample the diffusion volume to 1mm isometric, segment the tissues and compute fODFs.
Evaluation metrics #
- VB: anatomically valid bundles.
- IB: anatomically invalid bundles.
- VC: Valid connection. Connecting two regions that should be connected.
- IC: Invalid connection. Streamlines connecting two regions that should not be connected. (False positives!)
- NC: No-connections. Streamlines that stop in the middle of the white matter instead of reaching gray matter, which is anatomically implausible. A kind of false negative.
- TP/FP/FN: True positive, False positive, False negative
- OL: Overlap. The volume covered by a reconstructed bundle relative to the ground-truth bundle.
- OR: Overreach. Fraction of voxels that have streamlines, but shouldn’t. (Ground truth doesn’t have streamlines in those voxels.)
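A hedged sketch of how OL and OR can be computed from binary voxel masks (Tractometer-style; both ratios are normalized by the ground-truth volume here, although normalization conventions vary):

```python
import numpy as np

def overlap_overreach(bundle_mask, gt_mask):
    """Voxel-wise overlap (OL) and overreach (OR) from binary masks."""
    bundle = bundle_mask.astype(bool)
    gt = gt_mask.astype(bool)
    n_gt = np.count_nonzero(gt)
    ol = np.count_nonzero(bundle & gt) / n_gt      # fraction of GT voxels covered
    orr = np.count_nonzero(bundle & ~gt) / n_gt    # spill of the bundle outside the GT
    return ol, orr
```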
Training #
All agents:
- 1000 episodes, no early stopping.
- Episode: tracking n=4096 streamlines.
- Grid search per algorithm, per dataset, per experiment over:
- Learning rate \(\eta\)
- Discount factor \(\gamma\)
- Algorithm specific parameters
- Select the configuration with the best VC rate
- Five different random seeds
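A toy version of that selection loop (search ranges and the training stub are purely illustrative):

```python
from itertools import product
import random

def train_and_evaluate(lr, gamma):
    """Placeholder for a full training run; should return the VC rate of the resulting agent."""
    return random.random()  # stand-in for 1000 training episodes + tracking + scoring

# Hypothetical grid; the actual ranges are listed in the article.
grid = product([1e-5, 1e-4, 1e-3],              # learning rate
               [0.5, 0.65, 0.75, 0.85, 0.95])   # discount factor

best_vc, best_lr, best_gamma = max(
    (train_and_evaluate(lr, g), lr, g) for lr, g in grid)
```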
Tracking parameters #
- Fibercup: 10 seeds/voxel for training, 33 seeds/voxel for testing. Step size of 0.75mm.
- ISMRM2015: 2 seeds/voxel for training, 7 seeds/voxel for testing. Step size of 0.75mm.
- TractoInferno: 10 seeds/voxel. Step size of 0.375mm.
- Keep streamlines between 20mm and 200mm in length.
Exp 1: Agents #
- Explore all RL agents, find the best parameters, and train each agent with 5 random seeds.
- To assess the generalization capabilities of the algorithm, we use the agents trained on ISMRM2015 to track an in-vivo subject of the TractoInferno dataset.
- Baseline performance: compare trained agents to PFT algorithm (with same tracking parameters).
Exp 2: Seeding #
Using two environments for tracking with WM seeding might be a source of instability during training, since it creates a dependence between the backwards and forward trajectories (backwards trajectory depends on the first half-streamline generated), thus deviating from the usual reinforcement learning setting.
Exploring the impact of the seeding strategy on the training procedure and reconstructed tractograms (WM seeding vs WM/GM interface seeding).
- The WM/GM interface typically has fewer voxels than the WM, so the number of seeds/voxel is adjusted to obtain roughly the same number of resulting streamlines (see the quick calculation after this list):
- FiberCup: 100 seeds/voxel for training, 300 seeds/voxel for testing.
- ISMRM2015: 20 seeds/voxel for training, 60 seeds/voxel for testing.
- Tracking step size of 0.75mm
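The adjustment is just a matter of matching the total seed count; a quick sketch (illustrative, with a hypothetical voxel ratio):

```python
def matched_seed_density(seeds_per_voxel_wm, n_wm_voxels, n_interface_voxels):
    """Scale the interface seed density so that seeds/voxel * number of voxels stays constant."""
    return seeds_per_voxel_wm * n_wm_voxels / n_interface_voxels

# e.g. if the interface had ~10x fewer voxels than the WM, 10 seeds/voxel in WM
# would correspond to ~100 seeds/voxel at the interface (the FiberCup setup above).
```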
Exp 3: Retracking #
- WM seeding, but without the retracking regime: backwards tracking starts from the input signal at the seed’s position, using the first four directions of the reversed forward half-streamline.
- Same parameter search as in experiment 1.
Exp 4: State #
We explore a few variations of the inputs concatenated into the agent’s state to assess their contribution to producing better tractograms. Recall the information the input state contains (described above); we then vary the following (a sketch of the resulting state assembly follows the list):
- Raw diffusion signal: Provide the raw diffusion signal instead of the precomputed fODFs that were fed to the agent within the state.
- Previous directions: use the same state, but vary the number of previous directions included from {0, 2, 4}.
- WM Mask: use the same state, but remove the tracking mask from the state to see whether the agents can still stay in the WM even if they’re not told exactly where it is.
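A rough sketch of how these state variants could be assembled (purely illustrative; the actual TTL feature layout may differ):

```python
import numpy as np

def build_state(signal_at, wm_at, positions, prev_dirs, n_prev_dirs=2, use_wm_mask=False):
    """Concatenate the signal (fODF SH coefficients or raw DWI) at the streamline tip and its
    six axis-aligned neighbours, optionally the WM mask values, and the last few directions."""
    parts = [np.asarray(signal_at(p), dtype=float) for p in positions]  # tip + 6 neighbours
    if use_wm_mask:
        parts.append(np.array([wm_at(p) for p in positions], dtype=float))
    dirs = list(prev_dirs[-n_prev_dirs:]) if n_prev_dirs else []
    dirs = [np.zeros(3)] * (n_prev_dirs - len(dirs)) + [np.asarray(d) for d in dirs]  # pad early steps
    parts.extend(dirs)
    return np.concatenate(parts)
```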
Exp 5: Reward function #
The reward function defined above is designed to push the learning agents to reproduce the behaviour of classical tracking algorithms. We explore additions to this reward and how they affect the resulting tractograms (a combined sketch follows the list).
- Length bonus: Might lead to longer streamlines, thus fewer broken streamlines.
- \(\alpha_{length}\): hyper-parameter that controls amplitude of bonus.
- \(l_{max}\): max streamline length (in terms of number of pts since the step size is constant) $$r_t = |\max_{v(p_t)} \langle v(p_t), u_t \rangle| \cdot \langle u_t, u_{t-1} \rangle + \alpha_{length}(\frac{l(u)}{l_{max}})$$
- Reaching GM termination bonus: Try to encourage agents to avoid forming loops or broken streamlines.
- \(\alpha_{GM}\): hyper-parameter that controls amplitude of bonus.
- \(\mathbb{1}_{GM}\): 1 if \(u\) reached GM, 0 otherwise. (It’s the bonus)
$$r_t = |\max_{v(p_t)} \langle v(p_t), u_t \rangle| \cdot \langle u_t, u_{t-1} \rangle + \alpha_{GM} \mathbb{1}_{GM}(u)$$
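A hedged sketch combining the base alignment reward with the two bonuses described above (the α values and the GM indicator follow the definitions given; the helper itself is illustrative):

```python
def reward_with_bonuses(base_reward, streamline_len, max_len, reached_gm,
                        alpha_length=0.0, alpha_gm=0.0):
    """Base alignment reward plus an optional length bonus and a sparse GM-termination bonus."""
    r = base_reward
    r += alpha_length * (streamline_len / max_len)   # grows with the number of points tracked
    r += alpha_gm * (1.0 if reached_gm else 0.0)     # indicator bonus, only when GM is reached
    return r
```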
Results #
For precise results, see the full article and its figures. The discussion below analyses the results and offers a shorter, simpler review of the conclusions drawn from them.
Discussion #
Variations on the components of Track-to-Learn #
Off-policy algorithms are best-suited for tractography: TD3, SAC & SAC Auto consistently achieve good/best results in all relevant experiments. DDPG, TD3, SAC & SAC Auto are more stable during training (lower variance in rewards obtained compared to on-policy).
Q-functions vs. value functions: (hypothesis) taking an action during tracking is very sensitive: a bad step may end the streamline prematurely (outside the mask or violating the angle criterion). Only a very small range of actions leads to better states, so the expected state value is very small in every state and the algorithm may not be able to differentiate better states. Prefer Q-functions.
Directionality is crucial for the agents: The agent cannot know where it came from, thus it can’t infer where it should go. The input signal (diffusion signal and WM mask) by itself doesn’t give any information about the directionality of steps taken. It’s problematic since tractography is a sequential and oriented process.
Seeding strategy is left to the user’s preference: based on the results, interface seeding and WM seeding are roughly equivalent.
- For bundle coverage: white matter seeding.
- Connection accuracy: interface seeding.
Simplified state formulation: The following modifications have low impact on the tractograms:
- Decrease the number of directions from 4 to 2.
- Remove the WM mask
- Replace the fODFs with raw diffusion data
Bonus for reaching target gray matter may be beneficial: it only improves results in certain cases, so be careful.
Reward vs performance #
In several scenarios, a few algorithms achieve a high sum of cumulative rewards over the course of episodes, but actually obtain worse VC and OL rates.
- This is because the agent is rewarded on a per step basis. (Reward for each step taken)
- Longer streamlines tend to yield a higher cumulative reward.
- The reward doesn’t describe exactly what we are looking to model: the reward should be based on anatomical plausibility. Exploration on reward bonuses:
- Length bonus: no significant improvement, since agents are already implicitly rewarded based on streamline length.
- GM bonus: improvements only when the bonus is > 10 (because of the sparsity of that bonus), and it sometimes confuses the critic, producing a higher loss.
Voxel size vs step size #
- Going from FiberCup to the flipped FiberCup: good results, since the voxel size is the same in both datasets.
- Going from ISMRM2015 to TractoInferno: terrible results, since the voxel sizes differ (2 mm for ISMRM2015, 1 mm for TractoInferno) while the step size is 0.75 mm for both.
- The step size for TractoInferno had to be halved, from 0.75 mm to 0.375 mm.
It’s not a problem, however, when the step size is expressed in the voxel space.
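Expressed in voxel units, the same step automatically scales with the dataset’s resolution; a tiny illustration (values taken from the tracking parameters above):

```python
def step_in_mm(step_size_vox, voxel_size_mm):
    """A step defined in voxel space scales with the resolution of the volume."""
    return step_size_vox * voxel_size_mm

print(step_in_mm(0.375, 2.0))  # ISMRM2015 (2 mm voxels)    -> 0.75 mm
print(step_in_mm(0.375, 1.0))  # TractoInferno (1 mm voxels) -> 0.375 mm
```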
Comparison with Track-to-Learn #
- A variable number of episodes (80-150) in the original paper vs. a fixed 1000 episodes here => agents have time to better fine-tune their policies.
- Simple grid search here to find best hyperparameters.
- Here using the mm step size instead of voxel step size.
- Smaller step size than in original work => might put the agents to a disadvantage in crossings and turns of the FiberCup dataset.
- Similar conclusions to the performance of RL methods for tractography despite different results.
Discount factor #
- The grid search selected relatively low discount factors, \(\gamma \in [0.5, 0.85]\), compared to the standard in the RL literature, \(\gamma \in [0.9, 0.99]\).
- Contradicts what was proposed in the Track-to-Learn article where a higher discount was thought to lead to higher VC rates. Might be because of:
- \(\gamma = 0.85\) was the minimum tested in the previous article.
- Search limited by the grid search, instead of the Bayesian search from the previous article.
- Only TD3 algorithm was previously used when plotting the discount parameter vs different metrics.
- Lower discount factors might make the agents greedier and less concerned with future rewards, potentially producing loops.
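One way to put these values in perspective is the usual rule of thumb for the effective horizon of a discounted agent, 1/(1-γ) (a standard heuristic, not an analysis from the paper):

```python
def effective_horizon(gamma):
    """Rough number of future steps the agent effectively takes into account: 1 / (1 - gamma)."""
    return 1.0 / (1.0 - gamma)

# gamma = 0.50 -> ~2 steps ahead
# gamma = 0.85 -> ~7 steps ahead
# gamma = 0.99 -> ~100 steps ahead
```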
Bundle coverage #
- Compared to PFT, the RL agents exhibit competitive or superior performance with respect to connection rates.
- RL agents tend to avoid the outside edge of the WM tracking mask which leads to much lower overlap (OL) rates.
- To get higher rewards, RL agents might avoid taking the risk of early termination due to tracking on the edge of what’s allowed.
- The last step produced by the RL agents (which exits the WM) is kept => increases the OR.
Future works #
- Using asymmetric ODFs (aODF): remove the need for the absolute value function in the reward function, peaks would now have a direction, thus helping agents learn and disentangle the tracking direction.
- Partial volume maps, with or without a heuristic (i.e. a PFT-like algorithm), to soften the termination criteria and/or be integrated into the reward function to help agents better fill the volume of the WM. Or use surface normals to reward agents with a termination bonus based on the alignment of the last tracking segment with the normal.
- Use a von Mises-Fisher distribution instead of a normal distribution to sample actions from the agents (since the agent outputs a mean & variance).
- Try algorithms that produce actions in a discrete domain to leverage the Deep Q-Learning-based methods (which are popular and successful). Selecting directions on a discrete sphere instead of producing a 3D vector directly.
- The step size is constrained by the voxel size of the volume used during training. Maybe train the agents with several step sizes so they can generalize to the user’s preference, or let the agent decide the step size itself.
- Improve the reward function to represent what is truly expected from the learning agents: reward them not for alignment with their respective local peaks, but for anatomically plausible connections, which would better suit the tractography problem and prevent the agents from exploiting the reward function.
The conclusion is at the top of the page, listing the principal outcomes and recommendations from this paper. More details on the results and the hyperparameter search are included in the full article.