Think-at-Hard

Selective latent iterations to improve reasoning language models: iterate only on hard tokens, skip the easy ones.

+5.3% average accuracy gain
1.07× average iteration depth

Overview

Iterate smarter, not more

Looped transformers refine token predictions through multiple latent iterations, but iterating on every token wastes compute and risks latent overthinking, where correct predictions get flipped to errors.

Always-Iterate applies uniform iterations to all tokens, overthinking easy tokens while still underthinking hard ones
Think-at-Hard selectively iterates only on hard tokens, improving accuracy while skipping the extra iteration on about 93% of tokens
TaH Overview - Latent iterations can fix wrong predictions but also overthink correct ones

The Discovery

Most tokens don't need to think twice.

We identify latent overthinking, where extra iterations hurt rather than help, and show that selective iteration unlocks significant untapped potential.

8.7%
Tokens Corrected
The second iteration fixes 8.7% of first-pass mispredictions, showing latent iterations can genuinely help on hard tokens.
2.1%
Tokens Overthought
But it also flips 2.1% of correct predictions into errors, a latent overthinking phenomenon that mirrors explicit CoT overthinking.
+7.3%
Oracle Potential
An oracle policy that iterates only on mispredicted tokens achieves up to 7.3% improvement, and up to 32% with TaH's optimized architecture.

Oracle Iteration Policy

We establish an oracle policy $\pi$ that triggers additional iterations only when the reference LLM mispredicts the target token at the first pass. Using top-1 mismatch as the discrepancy metric, this oracle iterates on only 12–19% of tokens.

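Policy | NTP | AMC23 | MMLU100 | HE++
--- | --- | --- | --- | ---
Always-1 (no iteration) | 73.1 | 38.1 | 56.0 | 39.6
Always-2 (iterate all) | 79.7 | 40.6 | 60.0 | 40.9
Oracle (selective) | 81.8 (+2.1) | 47.9 (+7.3) | 62.0 (+2.0) | 43.3 (+2.4)
Oracle w. TaH | 89.3 (+9.6) | 68.8 (+28.2) | 85.0 (+25.0) | 72.9 (+32.0)

Values in parentheses are gains over Always-2.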

Key insight: TaH's architecture better utilizes the oracle policy, achieving >25% improvement and surpassing even Qwen3-4B. The oracle requires ground-truth tokens unavailable at inference, so TaH approximates it with a neural decider.
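
As a minimal sketch of the oracle rule described above (not the authors' code): given first-pass top-1 predictions and the ground-truth next tokens, the policy flags exactly the mismatches for another iteration. The names `oracle_policy` and `iterate_fraction` and the toy token ids are illustrative assumptions.

```python
from typing import List


def oracle_policy(first_pass_top1: List[int], targets: List[int]) -> List[bool]:
    """Return True where a token should receive another iteration.

    Assumption: "top-1 mismatch" means the first-pass argmax token id
    differs from the ground-truth next-token id.
    """
    if len(first_pass_top1) != len(targets):
        raise ValueError("predictions and targets must have the same length")
    return [pred != tgt for pred, tgt in zip(first_pass_top1, targets)]


def iterate_fraction(decisions: List[bool]) -> float:
    """Fraction of tokens sent to a deeper iteration."""
    return sum(decisions) / len(decisions) if decisions else 0.0


# Toy example with made-up token ids (not data from the paper).
preds = [5, 9, 2, 7, 7, 1, 4, 3]
golds = [5, 9, 3, 7, 7, 1, 4, 3]
decisions = oracle_policy(preds, golds)
print(decisions, iterate_fraction(decisions))
```

Because the rule needs `targets`, it can label training data but cannot run at inference, which is why a learned decider is required.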

Accuracy landscape across iterations
Token prediction accuracy landscape

Because Ouro trains all iterations to predict all tokens, predictable tokens across depths largely overlap, leaving more tokens unpredictable by any iteration. TaH instead specializes deeper iterations for hard tokens, reducing overlap and improving coverage under selective iteration.

Architecture

TaH Design

Three architectural innovations enable efficient selective latent iteration.

TaH Architecture Overview

TaH Overview. (a) Regular causal attention. (b) Duo-causal attention extends causality to two dimensions. (c) TaH selectively iterates or verbalizes tokens using LoRA adapters and a neural iteration decider.

Duo-Causal Attention

2D causality across token positions and iteration depths; compatible with FlashAttention, no custom CUDA kernels needed

Depth-aware LoRA

LoRA adapters at $d > 1$ shift the objective from next-token prediction to hard-token refinement with <3% extra parameters

Neural Iteration Decider

Lightweight MLP that predicts which tokens need deeper thinking, trained to imitate the oracle policy in a stable two-stage scheme

1. Input & Forward Pass
2. Duo-Causal Attention
3. Depth Adapter (LoRA)
4. Iteration Decider

Standard Forward Pass (Depth 1)

Token embeddings enter the LLM backbone for a standard forward pass at depth $d=1$. The model uses its original pretrained weights $\theta$ without any LoRA adaptation. This first iteration produces standard next-token predictions, which are correct for ~93% of tokens.

Key: The first pass preserves the pretrained model's strong next-token prediction ability. Only hard tokens proceed to deeper iterations.

Duo-Causal Attention

Unlike standard causal attention (1D: attend to previous positions), duo-causal attention extends causality to two dimensions: tokens attend to both previous positions and shallower iteration depths. Formally: $X_{\le i}^{(\le d)} = \{x_j^{(k)} \mid j \le i, k \le d\}$.

Key: This enables cross-depth information flow (deeper tokens can access shallower representations of all previous tokens) while maintaining full parallel training via FlashAttention compatibility.
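
The visibility rule $j \le i, k \le d$ can be written out as a boolean mask. The sketch below is illustrative only: it assumes every token exists at every depth and flattens (position, depth) pairs depth-major, neither of which is stated in the source. In TaH only hard tokens reach deeper iterations, so the real mask would cover a subset of these entries.

```python
from typing import List, Tuple


def duo_causal_mask(
    num_tokens: int, num_depths: int
) -> Tuple[List[Tuple[int, int]], List[List[bool]]]:
    """Build a duo-causal mask over (position, depth) pairs.

    mask[q][k] is True when query (i, d) may attend to key (j, kd),
    i.e. when j <= i and kd <= d. Depths are 1-indexed.
    Assumption: depth-major flattening (all depth-1 tokens come first).
    """
    index = [(i, d) for d in range(1, num_depths + 1) for i in range(num_tokens)]
    mask = [[(j <= i and kd <= d) for (j, kd) in index] for (i, d) in index]
    return index, mask


index, mask = duo_causal_mask(num_tokens=3, num_depths=2)
for (i, d), row in zip(index, mask):
    print(f"pos={i} depth={d}", "".join("x" if allowed else "." for allowed in row))
```

With this ordering every allowed key sits at or before its query in the flattened sequence, so the mask is a subset of an ordinary lower-triangular causal mask.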

Depth Adapter (LoRA)

At deeper iterations ($d > 1$), LoRA adapters activate on top of the shared backbone: $\theta_d = \theta + \Delta$. This shifts the model's objective from general next-token prediction to focused hard-token refinement. Residual connections across iterations simplify the refinement process.

Key: LoRA adds less than 3% extra parameters while enabling the model to specialize deeper iterations. Without LoRA, the shared weights must handle both objectives, limiting performance.
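
A small sketch of the depth switch $\theta_d = \theta + \Delta$ for a single weight matrix. It assumes the standard LoRA form $\Delta = BA$ with a low rank; the rank, the shapes, and the helper names (`effective_weight`, `lora_param_fraction`) are hypothetical, since the source gives only the parameter budget (<3%).

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain matrix product for small illustrative matrices."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def effective_weight(w: Matrix, lora_b: Matrix, lora_a: Matrix, depth: int) -> Matrix:
    """Return theta at depth 1 and theta + Delta (Delta = B @ A) at depth > 1."""
    if depth <= 1:
        return w  # first pass uses the backbone weights unchanged
    delta = matmul(lora_b, lora_a)
    return [[wv + dv for wv, dv in zip(w_row, d_row)] for w_row, d_row in zip(w, delta)]


def lora_param_fraction(d_out: int, d_in: int, rank: int) -> float:
    """Extra parameters of one rank-r adapter relative to the dense matrix."""
    return rank * (d_out + d_in) / (d_out * d_in)


w = [[1.0, 0.0], [0.0, 1.0]]
lora_b = [[1.0], [2.0]]   # shape (d_out, r) with r = 1
lora_a = [[0.5, 0.0]]     # shape (r, d_in)
print(effective_weight(w, lora_b, lora_a, depth=1))
print(effective_weight(w, lora_b, lora_a, depth=2))
print(lora_param_fraction(1024, 1024, rank=8))  # hypothetical sizes
```

The fraction helper shows why a low-rank adapter stays cheap: its cost grows with `rank * (d_out + d_in)` rather than `d_out * d_in`.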

Neural Iteration Decider

A lightweight MLP ($\mathcal{I}_\phi$) reads concatenated hidden states from shallow, middle, and final LLM layers to predict a continuation probability $\hat{c}_i^{(d)} \in [0,1]$. If $\hat{c}_i^{(d)} < c_{\text{threshold}}$, the token verbalizes; otherwise it continues to the next iteration depth.

Key: The decider is trained in Stage 2 to imitate the oracle policy via weighted binary cross-entropy. It achieves ~83% accuracy at predicting the oracle's decisions. Tokens like "But" (34%) and "So" (18%) are most frequently selected for deeper iteration.
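
The decision rule can be sketched as follows. This is a toy one-hidden-layer MLP in pure Python under stated assumptions: the layer sizes, the ReLU activation, the weights, and the 0.5 threshold are all hypothetical, and only the rule "verbalize if $\hat{c} < c_{\text{threshold}}$, otherwise continue" comes from the text.

```python
import math
from typing import List, Sequence


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def continuation_prob(
    shallow: Sequence[float],
    middle: Sequence[float],
    final: Sequence[float],
    w1: List[List[float]],
    b1: List[float],
    w2: List[float],
    b2: float,
) -> float:
    """Map concatenated hidden states to a continuation probability in [0, 1]."""
    features = list(shallow) + list(middle) + list(final)
    hidden = [
        max(0.0, sum(w * x for w, x in zip(row, features)) + b)  # ReLU (assumed)
        for row, b in zip(w1, b1)
    ]
    logit = sum(w * h for w, h in zip(w2, hidden)) + b2
    return sigmoid(logit)


def decide(c_hat: float, threshold: float) -> str:
    """Verbalize when c_hat < threshold, otherwise go one iteration deeper."""
    return "verbalize" if c_hat < threshold else "continue"


# Toy hidden states and weights, chosen only to make the arithmetic easy to follow.
c_hat = continuation_prob(
    shallow=[1.0], middle=[0.0], final=[-1.0],
    w1=[[1.0, 0.0, -1.0]], b1=[0.0], w2=[1.0], b2=-2.0,
)
print(c_hat, decide(c_hat, threshold=0.5))
```

Raising the threshold sends fewer tokens to deeper iterations, so it acts as the knob that trades accuracy against average depth.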

Two-stage training scheme

The backbone LLM and iteration decider are tightly coupled, so joint training is unstable. We decouple them with a two-stage approach under a fixed oracle policy $\pi$:

Stage 1: Backbone Training

Fine-tune the LLM ($\theta$) and LoRA ($\Delta$) with $\pi$-guided iteration. Loss = standard next-token prediction at the oracle-determined depth. This preserves first-iteration accuracy for easy tokens while training deeper iterations to refine hard ones.
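
One way to read "next-token prediction at the oracle-determined depth" is sketched below: for each token, take the logits produced at the depth the oracle chose and average the cross-entropy. The dense `logits_by_depth` layout is a simplifying assumption (in practice easy tokens would not be computed at depth 2), and the function names are illustrative.

```python
import math
from typing import List


def cross_entropy(logits: List[float], target: int) -> float:
    """Numerically stable -log softmax(logits)[target]."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[target]


def stage1_loss(
    logits_by_depth: List[List[List[float]]],
    targets: List[int],
    oracle_depth: List[int],
) -> float:
    """Mean next-token loss, reading each token at its oracle-chosen depth.

    logits_by_depth[d - 1][i] holds the logits of token i at depth d.
    """
    losses = [
        cross_entropy(logits_by_depth[d - 1][i], t)
        for i, (t, d) in enumerate(zip(targets, oracle_depth))
    ]
    return sum(losses) / len(losses)


# Toy logits over a 2-word vocabulary for two tokens at two depths.
logits_by_depth = [
    [[0.0, 0.0], [2.0, 0.0]],  # depth 1
    [[0.0, 0.0], [0.0, 0.0]],  # depth 2
]
targets = [0, 1]
loss_oracle = stage1_loss(logits_by_depth, targets, oracle_depth=[1, 2])
loss_depth1 = stage1_loss(logits_by_depth, targets, oracle_depth=[1, 1])
print(loss_oracle, loss_depth1)
```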

Stage 2: Decider Training

Freeze the backbone and train the decider ($\phi$) to imitate $\pi$'s continuation decisions via weighted binary cross-entropy. Class weights handle the label imbalance between continue (~7%) and stop (~93%) decisions.
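
A minimal sketch of a class-weighted binary cross-entropy for the continue/stop labels. The inverse-frequency weighting and the normalization by total weight are common choices assumed here for illustration; the source does not state which weights are used.

```python
import math
from typing import List, Tuple


def weighted_bce(
    probs: List[float],
    labels: List[int],
    w_continue: float,
    w_stop: float,
    eps: float = 1e-12,
) -> float:
    """Weighted BCE where label 1 = continue and label 0 = stop."""
    total = 0.0
    weight_sum = 0.0
    for p, y in zip(probs, labels):
        p = min(max(p, eps), 1.0 - eps)  # avoid log(0)
        if y == 1:
            total += -w_continue * math.log(p)
            weight_sum += w_continue
        else:
            total += -w_stop * math.log(1.0 - p)
            weight_sum += w_stop
    return total / weight_sum


def inverse_frequency_weights(continue_rate: float) -> Tuple[float, float]:
    """Assumed heuristic: weight each class by the inverse of its frequency."""
    return 1.0 / continue_rate, 1.0 / (1.0 - continue_rate)


# With ~7% continue labels, the rare class gets a much larger weight.
w_continue, w_stop = inverse_frequency_weights(0.07)
print(w_continue, w_stop)
print(weighted_bce([0.5, 0.5], [1, 0], w_continue=2.0, w_stop=1.0))
```

Without such weights, a decider that always predicts "stop" would already be right on roughly 93% of labels while never triggering a deeper iteration.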

Experiments

Benchmark Results

Consistent gains across nine reasoning benchmarks at three model scales.

TaH+ Performance Summary
+5.3% vs Standard
35.2% average accuracy
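Method | AIME25 | Olympiad | AMC23 | MATH500 | GSM8K | GPQA | MMLU | HE++ | MBPP++ | Average
--- | --- | --- | --- | --- | --- | --- | --- | --- | --- | ---
Standard | 1.9 | 15.4 | 22.7 | 39.9 | 58.2 | 31.1 | 54.2 | 16.8 | 28.8 | 29.9
SoftThink | 2.9 | 14.0 | 22.2 | 39.6 | 55.9 | 24.7 | 53.0 | 14.3 | 29.5 | 28.5
Ouro | 2.1 | 14.2 | 19.7 | 37.4 | 56.6 | 35.4 | 54.0 | 18.9 | 23.5 | 29.1
AlwaysThink | 1.3 | 12.6 | 21.9 | 37.8 | 52.6 | 30.8 | 51.4 | 9.1 | 13.8 | 25.7
TaH | 2.1 | 19.1 | 24.1 | 46.2 | 63.6 | 29.0 | 56.4 | 21.6 | 33.9 | 32.9
TaH+ | 4.6 | 20.6 | 24.7 | 51.8 | 67.6 | 31.3 | 59.0 | 22.0 | 35.1 | 35.2

Real-world efficiency on A800 GPU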

1.7B model, AIME25, 8K max token length on a single NVIDIA A800 GPU:

Metric | Standard | AlwaysThink | TaH
--- | --- | --- | ---
Avg. Depth | 1.00 | 2.00 | 1.06
Memory (GB) | 4.3 | 6.8 | 4.6
Latency (s) | 210.6 | 747.2 | 301.4
Throughput (tok/s) | 38.9 | 11.0 | 27.2

TaH iterates twice on only 6% of tokens, using 1.48× less memory and decoding 2.48× faster than AlwaysThink.
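
The quoted ratios follow directly from the table; the short check below recomputes them. The dictionary layout and the `ratio` helper are illustrative, and reading the iterated fraction as average depth minus one assumes that no token goes beyond two iterations.

```python
# Values copied from the efficiency table (1.7B model, AIME25, single A800).
table = {
    "Standard":    {"depth": 1.00, "memory_gb": 4.3, "latency_s": 210.6, "tok_per_s": 38.9},
    "AlwaysThink": {"depth": 2.00, "memory_gb": 6.8, "latency_s": 747.2, "tok_per_s": 11.0},
    "TaH":         {"depth": 1.06, "memory_gb": 4.6, "latency_s": 301.4, "tok_per_s": 27.2},
}


def ratio(metric: str, numerator: str, denominator: str) -> float:
    """Ratio of one method's metric to another's."""
    return table[numerator][metric] / table[denominator][metric]


memory_ratio = ratio("memory_gb", "AlwaysThink", "TaH")
latency_ratio = ratio("latency_s", "AlwaysThink", "TaH")
# Assumption: at most two iterations, so avg depth = 1 + fraction iterated twice.
iterated_fraction = table["TaH"]["depth"] - 1.0

print(round(memory_ratio, 2), round(latency_ratio, 2), round(iterated_fraction, 2))
```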

Contributions

Key Contributions

Selective Latent Iteration

First to identify latent overthinking in looped transformers and propose selective iteration as a new design principle: iterate only on hard tokens for better quality and efficiency.

Specialized Architecture

Duo-causal attention, depth-aware LoRA, and a neural iteration decider: three components that natively support selective iteration with full training parallelism.

Efficient & Effective

+5.3–6.2% accuracy gains across nine benchmarks with <3% extra parameters and only 1.07× average iteration depth; 93% of tokens need just one pass.

Team

Meet the researchers.

¹Tsinghua University    ²Infinigence AI    ³Shanghai Jiao Tong University
*Equal contribution

Citation

Cite our work.

If you find TaH useful for your research, please consider citing our paper.

BibTeX
@article{fu2025think,
  title={Think-at-Hard: Selective Latent Iterations to Improve Reasoning Language Models},
  author={Fu, Tianyu and You, Yichen and Chen, Zekai and Dai, Guohao and Yang, Huazhong and Wang, Yu},
  journal={arXiv preprint arXiv:2511.08577},
  year={2025}
}