I’m trying to reproduce the results from the SANDWiCH word sense disambiguation paper. To do this, I’m fine-tuning a `DebertaV2ForSequenceClassification` model with `microsoft/deberta-v3-small` as the base model, using the same training parameters as given in the paper.
However, I keep seeing numerical stability issues. As the charts below show, at some point in the training there is a huge spike in the loss, after which the gradient norm becomes NaN. What can I do to diagnose and resolve this?
I ran some experiments in a Colab environment to narrow this down; the results are below.
The most likely cause is not one single bug. It is a combination of three things:

- The paper’s training description is internally ambiguous. In the main text, SANDWiCH says the cross-encoder uses DeBERTa-v3-small with batch size 64, 10 epochs, LR 2e-5, gradient clipping at 1, cosine annealing, and binary cross-entropy with logits. But Appendix Table 7 lists AdamW, LR candidates {1e-7, 1e-6, 1e-5, 1e-4}, max tokens 512, and cross-entropy with logits. The released repo also says the public code differs slightly from the paper version. So “I used the same parameters as the paper” can still hide a loss mismatch.
- `DebertaV2ForSequenceClassification` does not default to the paper’s likely loss. In the Hugging Face docs, when `num_labels == 1`, the built-in loss is regression / MSE. When `num_labels > 1`, it is cross-entropy. Multi-label behavior is a separate path using `problem_type="multi_label_classification"`, and the example uses float labels. So if you intended a 1-logit BCE objective but let the model choose its own loss, you may not actually be training what you think you are training. (Hugging Face)
- If you are using FP16, one overflow step can poison the run. Hugging Face’s Trainer docs say BF16 is generally preferred over FP16 because it is more numerically stable and does not require loss scaling. PyTorch’s AMP docs also warn that AMP/FP16 may not work for every model, and that the `GradScaler` can reduce its scale below 1 when gradients overflow. There is public precedent for DeBERTa-family precision trouble as well: Microsoft has an mDeBERTa issue saying FP32 was needed there, and there is a public Transformers issue where DeBERTa training produced NaN gradients and an infinite grad norm. (Hugging Face)
What your chart usually means
A sharp loss spike followed by grad_norm = NaN usually means one update went non-finite, then everything after it is contaminated. Gradient clipping helps only if gradients are still finite. Once they are already inf or nan, clipping is no longer the main safeguard.
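A quick way to convince yourself that clipping cannot rescue an already-broken step: once a gradient contains `inf`, `clip_grad_norm_` returns a non-finite total norm, and scaling by the (zero) clip coefficient turns the `inf` entry into `nan` rather than repairing it. A minimal sketch:

```python
import torch

# Clipping does not rescue a non-finite gradient. The total norm is inf,
# so the clip coefficient is 0, and inf * 0 = nan: the grad stays poisoned.
p = torch.nn.Parameter(torch.ones(3))
p.grad = torch.tensor([1.0, float("inf"), 2.0])

total_norm = torch.nn.utils.clip_grad_norm_([p], max_norm=1.0)
print(torch.isfinite(total_norm).item())    # the norm itself is non-finite
print(torch.isfinite(p.grad).all().item())  # the grad is still non-finite
```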
The highest-probability failure modes in your setup
1. Loss mismatch
This is the first thing I would check.
There are two clean interpretations of the paper:
- Main-text interpretation: one relevance logit, trained with `BCEWithLogitsLoss`
- Appendix interpretation: two logits, trained with `CrossEntropyLoss`
What you do not want is a hybrid such as:
- `num_labels=1` with the built-in loss and integer labels
- one logit with `sigmoid` followed by `BCELoss`
- `num_labels=2` but still thinking in BCE semantics
PyTorch’s docs are explicit that BCEWithLogitsLoss is more numerically stable than sigmoid followed by BCE because it uses a stabilized formulation. (PyTorch Docs)
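The saturation failure is easy to demonstrate. In float32, `sigmoid(100.0)` rounds to exactly 1.0, so the hand-written `log(1 - sigmoid(x))` term hits `log(0)`, while the fused formulation stays finite. A small sketch (the values are illustrative, not from the paper):

```python
import torch
import torch.nn.functional as F

# A saturated positive logit with a wrong-side (0.0) label: the worst case
# for a hand-written sigmoid + log BCE.
logits = torch.tensor([100.0])
labels = torch.tensor([0.0])

# Naive formulation: sigmoid(100) == 1.0 in float32, so log(1 - 1.0) = -inf.
naive = -(labels * torch.log(torch.sigmoid(logits))
          + (1 - labels) * torch.log(1 - torch.sigmoid(logits)))

# Fused formulation uses a log-sum-exp rewrite and stays finite (~100.0).
stable = F.binary_cross_entropy_with_logits(logits, labels)

print(naive.item())   # inf
print(stable.item())  # ~100.0
```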
2. Label dtype / shape mismatch
This is easy to miss with DeBERTa.
A public Transformers issue shows that when the model ends up on the MSE path and labels are integers, CUDA can fail with mse_cuda not implemented for 'Long'; the reporter concluded the fix was to use float labels. That is not your exact error, but it is strong evidence that DeBERTa sequence-classification loss routing and label dtype matter here. (GitHub)
For BCE-style training, labels should be float values in {0.0, 1.0}. For CE-style training with two logits, labels should be integer class IDs.
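To make the two conventions concrete, here is a sketch of the label prep for each path (the variable names are illustrative):

```python
import torch

raw_labels = [0, 1, 1, 0]

# BCE path: one logit per example, float targets in {0.0, 1.0},
# shaped to match logits of shape [B] (or [B, 1] before squeezing).
bce_labels = torch.tensor(raw_labels, dtype=torch.float32)

# CE path: two logits per example, integer class IDs of shape [B].
ce_labels = torch.tensor(raw_labels, dtype=torch.long)

print(bce_labels.dtype, ce_labels.dtype)  # torch.float32 torch.int64
```

Passing `ce_labels` to the BCE path (or vice versa) is exactly the kind of silent mismatch that routes DeBERTa onto the wrong built-in loss.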
3. FP16 overflow
If you are on fp16=True, this becomes a prime suspect.
The Trainer docs say BF16 is preferred when supported. PyTorch AMP also says FP16 may not work for every model. If the run becomes stable in FP32 or BF16 but not FP16, that is your answer. (Hugging Face)
What I would do, in order
Step 1. Stop relying on the model’s built-in loss
Make the objective explicit.
If you want the paper main-text behavior
Use one logit and compute BCE manually with the stable loss:
outputs = model(
input_ids=batch["input_ids"],
attention_mask=batch["attention_mask"],
)
logits = outputs.logits.squeeze(-1) # shape [B]
labels = batch["labels"].float().view(-1) # 0.0 / 1.0
loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, labels)
This matches the paper’s main-text wording and PyTorch’s stable BCE recommendation.
If you want the appendix behavior
Use two logits and CE:
model = AutoModelForSequenceClassification.from_pretrained(
"microsoft/deberta-v3-small",
num_labels=2,
)
loss = torch.nn.functional.cross_entropy(logits, labels.long())
The important part is to choose one interpretation and stay consistent.
Step 2. Run a short debug job in pure FP32
No FP16. No BF16. Same data. Same seed. Same optimizer. Same loss.
Interpretation:
- Stable in FP32, unstable in FP16 → precision problem
- Unstable even in FP32 → loss, labels, data, or optimizer wiring problem
That split saves time.
Step 3. If mixed precision is needed, prefer BF16 over FP16
If your hardware supports BF16, switch to that before trying anything more exotic. Hugging Face explicitly recommends it over FP16 for stability. (Hugging Face)
Step 4. Verify effective batch size
The paper says batch size 64, but your true optimization batch may be larger if you use gradient accumulation or multiple GPUs. Hugging Face’s Trainer logs total train batch size including parallelism and accumulation, and documents that gradient accumulation changes how often optimizer updates occur. (Hugging Face)
If your effective batch is larger than intended, the paper LR may be too aggressive in your setup.
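A back-of-envelope check, with placeholder numbers you should replace with your run’s actual settings:

```python
# Effective optimization batch = per-device batch * accumulation steps * GPUs.
# The values below are assumptions; plug in your own.
per_device_batch = 16
grad_accum = 4
n_gpus = 1

effective_batch = per_device_batch * grad_accum * n_gpus
print(effective_batch)  # 64 here; compare against the paper's stated 64
```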
Step 5. Add anomaly detection and catch the first bad batch
PyTorch has a built-in tool for exactly this:
torch.autograd.set_detect_anomaly(True)
The docs say set_detect_anomaly(..., check_nan=True) raises when backward generates NaNs. Use it for a short run only, because it is slower. (PyTorch Docs)
Also log, for every batch:
- `labels.dtype`, `labels.shape`, unique label values
- `logits.abs().max()`
- `torch.isfinite(loss)`
- `torch.isfinite(logits).all()`
- dataset indices for the batch
That tells you whether the first non-finite value came from data, logits, or the optimizer step.
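A minimal helper along those lines, a sketch to drop into a training loop (the function name and return shape are my own, not from any library):

```python
import torch

def batch_health(logits, labels, loss):
    """Collect per-batch finiteness stats so the first bad batch is visible."""
    return {
        "labels_dtype": str(labels.dtype),
        "labels_shape": tuple(labels.shape),
        "unique_labels": torch.unique(labels).tolist(),
        "max_abs_logit": float(logits.detach().abs().max()),
        "finite_loss": bool(torch.isfinite(loss)),
        "finite_logits": bool(torch.isfinite(logits).all()),
    }

# Example with dummy tensors standing in for a real training step:
print(batch_health(torch.randn(4, 1),
                   torch.tensor([0.0, 1.0, 1.0, 0.0]),
                   torch.tensor(0.7)))
```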
Step 6. If you use custom AMP, unscale before clipping
PyTorch’s AMP examples are explicit: call scaler.unscale_(optimizer) before clip_grad_norm_. Otherwise you clip scaled gradients, which makes the threshold meaningless. (PyTorch Docs)
Step 7. Lower the LR one notch and add warmup
Given the ambiguity in the paper, I would not treat 2e-5 as sacred. A good stabilization test is:
- try `1e-5`
- use a short warmup
- keep `max_grad_norm=1.0`
- keep weight decay modest
This is an engineering recommendation, not a claim that the paper used warmup.
Step 8. Temporarily reduce max length
Appendix Table 7 says max tokens 512. Long sentence-definition pairs make extreme batches more likely. For diagnosis, cut to 256 and see whether the spike disappears. If it does, sequence length is part of the trigger.
Things that are easy to misread
logging_nan_inf_filter=True does not fix training
The Trainer docs say this only filters NaN/Inf for logging. It does not change gradients or optimizer behavior. So if your dashboard looks calmer than the raw run, that setting may be hiding the symptom, not solving it. (Hugging Face)
A DeBERTa checkpoint loading “missing/unexpected keys” is normal
If you load a pretrained backbone into a sequence-classification head, new classifier weights are initialized and task-specific pretraining heads are dropped. That is normal. It is not the cause of NaN spikes.
My best diagnosis for your case
Most likely ranking:
- You are not training the exact loss you think you are training.
- FP16 is causing one overflow step.
- The paper LR / batch / 512-token regime is acting as the trigger.
Minimal stable baseline I would start from
- `microsoft/deberta-v3-small`
- `num_labels=1`
- explicit `binary_cross_entropy_with_logits`
- labels cast to `float32`
- FP32 for debugging
- LR `1e-5`
- cosine scheduler
- `max_grad_norm=1.0`
- shorter max length first
- anomaly detection on for a short repro
After that is stable:
- restore 512 tokens
- switch to BF16 if supported
- raise LR if needed
- restore your full batch / accumulation setup
That is the fastest path to separating “wrong objective” from “precision overflow.”
Below is a single-file demo. It uses microsoft/deberta-v3-xsmall by default because it is smaller than deberta-v3-small, and SetFit/sst2 because it is a simple public binary dataset on the Hub. DeBERTa’s built-in sequence-classification loss is a trap here because num_labels == 1 defaults to regression, not BCE. PyTorch also documents that BCEWithLogitsLoss is the stable form and safe to autocast, unlike hand-written sigmoid + BCE formulas. (Hugging Face)
# deberta_repro_vs_fix_demo.py
#
# deps:
# pip install -U "torch>=2.2" "transformers>=4.46" "datasets>=2.18" "sentencepiece>=0.1.99"
#
# What this script shows
# ----------------------
# 1) BAD_REPRO
# Intentionally unstable binary loss:
# - computes sigmoid(logits) first
# - then does manual log-based BCE
# - optionally uses fp16 on CUDA
# - no clipping, no warmup
# This often creates inf/nan quickly.
#
# 2) GOOD_FIX
# Safer setup:
# - explicit binary_cross_entropy_with_logits(...)
# - float32 on CPU
# - bf16 on CUDA if supported, otherwise fp32
# - smaller LR
# - warmup + cosine decay
# - grad clipping
#
# URLs for context
# ----------------
# SANDWiCH paper:
# https://arxiv.org/pdf/2503.05958
#
# HF DeBERTa docs:
# https://huggingface.co/docs/transformers/en/model_doc/deberta-v2
#
# HF DeBERTa source showing the num_labels==1 regression path:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/deberta_v2/modeling_deberta_v2.py
#
# PyTorch BCEWithLogitsLoss docs:
# https://docs.pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html
#
# PyTorch AMP docs:
# https://docs.pytorch.org/docs/stable/amp.html
#
# Model card:
# https://huggingface.co/microsoft/deberta-v3-xsmall
#
# Dataset:
# https://huggingface.co/datasets/SetFit/sst2
#
# Notes
# -----
# - This is a demo. BAD_REPRO is intentionally wrong.
# - GOOD_FIX is the part to copy into a real training setup.
# - For a closer match to your real run, change MODEL_NAME to:
# "microsoft/deberta-v3-small"
# if your GPU RAM allows it.
# - On CPU this stays in float32.
# - On CUDA it prefers bf16 if available. That is safer than fp16.
import math
import os
import random
import time
from contextlib import nullcontext
import torch
import torch.nn.functional as F
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# -------------------------
# user-editable settings
# -------------------------
SEED = 1234
# Smaller default model for 10-16 GB class hardware.
MODEL_NAME = "microsoft/deberta-v3-xsmall"
# MODEL_NAME = "microsoft/deberta-v3-small" # closer to your real case, heavier
DATASET_NAME = "SetFit/sst2"
# Keep the run small enough for CPU or modest GPU.
TRAIN_ROWS = 384
VAL_ROWS = 128
MAX_LENGTH = 128
# Physical batch size. Small on purpose for low VRAM/RAM.
BATCH_SIZE = 4
# Two short runs.
BAD_STEPS = 20
GOOD_STEPS = 40
# Bad run. Intentionally unstable.
BAD_LR = 3e-4
BAD_LOGIT_SCALE = 64.0 # increase if BAD_REPRO stays finite on your machine
BAD_USE_FP16_IF_CUDA = True
# Good run. Safer defaults.
GOOD_LR = 1e-5
GOOD_WEIGHT_DECAY = 0.01
GOOD_ADAM_EPS = 1e-6
GOOD_CLIP_NORM = 1.0
GOOD_WARMUP_RATIO = 0.10
# Optional debugging.
USE_ANOMALY_DETECTION = False
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Optional performance hint.
try:
torch.set_float32_matmul_precision("high")
except Exception:
pass
def set_seed(seed: int) -> None:
random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
def detect_runtime():
"""Choose a safe precision policy."""
if not torch.cuda.is_available():
return {
"device": torch.device("cpu"),
"bad_precision": "fp32",
"good_precision": "fp32",
"gpu_name": None,
"vram_gb": None,
}
props = torch.cuda.get_device_properties(0)
try:
bf16_ok = torch.cuda.is_bf16_supported()
except Exception:
bf16_ok = False
return {
"device": torch.device("cuda"),
"bad_precision": "fp16" if BAD_USE_FP16_IF_CUDA else "fp32",
"good_precision": "bf16" if bf16_ok else "fp32",
"gpu_name": props.name,
"vram_gb": round(props.total_memory / (1024 ** 3), 2),
}
def autocast_ctx(device, precision: str):
"""Autocast only on CUDA."""
if device.type != "cuda":
return nullcontext()
if precision == "fp16":
return torch.autocast(device_type="cuda", dtype=torch.float16)
if precision == "bf16":
return torch.autocast(device_type="cuda", dtype=torch.bfloat16)
return nullcontext()
def cosine_lr(step: int, total_steps: int, warmup_steps: int, base_lr: float) -> float:
"""Simple warmup + cosine decay."""
if step <= warmup_steps:
return base_lr * float(step) / float(max(1, warmup_steps))
progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))
def grad_global_norm(model) -> float:
"""Return NaN if any grad is non-finite."""
total_sq = 0.0
saw_grad = False
for p in model.parameters():
if p.grad is None:
continue
saw_grad = True
g = p.grad.detach()
if not torch.isfinite(g).all():
return float("nan")
total_sq += float(g.float().pow(2).sum().item())
if not saw_grad:
return 0.0
return total_sq ** 0.5
def any_nonfinite_params(model) -> bool:
for p in model.parameters():
if not torch.isfinite(p.detach()).all():
return True
return False
def move_batch(batch, device):
return {k: v.to(device) for k, v in batch.items()}
def endless(loader):
while True:
for batch in loader:
yield batch
def load_small_data(tokenizer):
"""
Simple public dataset with train/validation/test splits.
We use only a tiny subset for a fast demo.
"""
ds = load_dataset(DATASET_NAME)
train_ds = ds["train"].shuffle(seed=SEED).select(range(min(TRAIN_ROWS, len(ds["train"]))))
val_ds = ds["validation"].shuffle(seed=SEED).select(range(min(VAL_ROWS, len(ds["validation"]))))
def tokenize_fn(batch):
return tokenizer(
batch["text"],
truncation=True,
max_length=MAX_LENGTH,
)
train_ds = train_ds.map(tokenize_fn, batched=True)
val_ds = val_ds.map(tokenize_fn, batched=True)
keep_cols = ["input_ids", "attention_mask", "label"]
train_ds = train_ds.remove_columns([c for c in train_ds.column_names if c not in keep_cols])
val_ds = val_ds.remove_columns([c for c in val_ds.column_names if c not in keep_cols])
return train_ds, val_ds
def make_collator(tokenizer):
def collate(examples):
features = [{"input_ids": x["input_ids"], "attention_mask": x["attention_mask"]} for x in examples]
batch = tokenizer.pad(features, padding=True, return_tensors="pt")
# Float labels because GOOD_FIX uses BCE-with-logits on one raw logit.
batch["labels"] = torch.tensor([x["label"] for x in examples], dtype=torch.float32).unsqueeze(-1)
return batch
return collate
def evaluate(model, loader, device, precision):
model.eval()
total = 0
correct = 0
mean_prob = 0.0
num_batches = 0
with torch.no_grad():
for batch in loader:
batch = move_batch(batch, device)
labels = batch.pop("labels")
with autocast_ctx(device, precision):
logits = model(**batch).logits
probs = torch.sigmoid(logits.float())
preds = (probs >= 0.5).float()
correct += int((preds == labels).sum().item())
total += labels.numel()
mean_prob += float(probs.mean().item())
num_batches += 1
return {
"acc": correct / max(total, 1),
"mean_prob": mean_prob / max(num_batches, 1),
}
def bad_loss(logits, labels):
"""
Intentionally unstable anti-pattern.
Do NOT use this in real training.
It is here only to reproduce non-finite behavior.
Why it is bad:
- applies sigmoid first
- manual log(...) terms can hit log(0)
- scaled logits make saturation happen sooner
"""
scaled_logits = logits * BAD_LOGIT_SCALE
probs = torch.sigmoid(scaled_logits)
loss = -(labels * torch.log(probs) + (1.0 - labels) * torch.log(1.0 - probs)).mean()
return loss, scaled_logits, probs
def good_loss(logits, labels):
"""
Correct stable path.
Feed raw logits directly into BCE-with-logits.
Do NOT apply sigmoid before this.
"""
return F.binary_cross_entropy_with_logits(logits, labels)
def run_once(
name: str,
model_name: str,
device,
precision: str,
train_loader,
val_loader,
use_good_loss: bool,
steps: int,
lr: float,
weight_decay: float = 0.0,
adam_eps: float = 1e-8,
clip_norm: float | None = None,
):
print(f"\n=== RUN: {name} ===")
print(
f"model={model_name} precision={precision} steps={steps} "
f"lr={lr} wd={weight_decay} eps={adam_eps} clip_norm={clip_norm}"
)
# Important:
# num_labels=1 by itself does NOT mean BCE in DeBERTa.
# We set num_labels=1 only because we are computing the loss ourselves.
model = AutoModelForSequenceClassification.from_pretrained(
model_name,
num_labels=1,
).to(device)
optimizer = torch.optim.AdamW(
model.parameters(),
lr=lr,
weight_decay=weight_decay,
eps=adam_eps,
)
scaler = torch.amp.GradScaler("cuda", enabled=(device.type == "cuda" and precision == "fp16"))
if USE_ANOMALY_DETECTION:
torch.autograd.set_detect_anomaly(True)
warmup_steps = int(GOOD_WARMUP_RATIO * steps) if use_good_loss else 0
train_iter = endless(train_loader)
summary = {
"name": name,
"steps_done": 0,
"stopped_early": False,
"fail_reason": None,
"best_val_acc": 0.0,
}
start_time = time.time()
for step in range(1, steps + 1):
model.train()
batch = move_batch(next(train_iter), device)
labels = batch.pop("labels")
optimizer.zero_grad(set_to_none=True)
try:
with autocast_ctx(device, precision):
logits = model(**batch).logits
if use_good_loss:
loss = good_loss(logits, labels)
tracked_logits = logits.detach().float()
tracked_probs = torch.sigmoid(logits.detach().float())
else:
loss, tracked_logits, tracked_probs = bad_loss(logits, labels)
if scaler.is_enabled():
scaler.scale(loss).backward()
# Unscale before clipping. This matters in AMP.
scaler.unscale_(optimizer)
else:
loss.backward()
if clip_norm is not None:
torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
grad_norm = grad_global_norm(model)
if scaler.is_enabled():
scaler.step(optimizer)
scaler.update()
else:
optimizer.step()
# Simple scheduler only on the good path.
if use_good_loss:
new_lr = cosine_lr(step, steps, warmup_steps, lr)
for group in optimizer.param_groups:
group["lr"] = new_lr
loss_value = float(loss.detach().float().item())
max_abs_logit = float(tracked_logits.abs().max().item())
mean_prob = float(tracked_probs.mean().item())
summary["steps_done"] = step
print(
f"[{name}] step={step:03d} "
f"loss={loss_value:>10.6f} "
f"grad_norm={grad_norm:>10.6f} "
f"max|logit|={max_abs_logit:>8.4f} "
f"mean_prob={mean_prob:>7.4f} "
f"lr={optimizer.param_groups[0]['lr']:.2e}"
)
# Stop as soon as anything goes non-finite.
if (
not math.isfinite(loss_value)
or not math.isfinite(grad_norm)
or any_nonfinite_params(model)
):
summary["stopped_early"] = True
summary["fail_reason"] = "non-finite loss/grad/param detected"
print(f"[{name}] STOP: {summary['fail_reason']}")
break
# Tiny eval every 10 steps.
if step % 10 == 0 or step == steps:
metrics = evaluate(model, val_loader, device, precision)
summary["best_val_acc"] = max(summary["best_val_acc"], metrics["acc"])
print(
f"[{name}] eval step={step:03d} "
f"val_acc={metrics['acc']:.4f} "
f"val_mean_prob={metrics['mean_prob']:.4f}"
)
except Exception as e:
summary["stopped_early"] = True
summary["fail_reason"] = f"{type(e).__name__}: {str(e)[:300]}"
print(f"[{name}] EXCEPTION: {summary['fail_reason']}")
break
elapsed = time.time() - start_time
summary["elapsed_sec"] = round(elapsed, 2)
del model
del optimizer
if torch.cuda.is_available():
torch.cuda.empty_cache()
return summary
def main():
set_seed(SEED)
rt = detect_runtime()
device = rt["device"]
print("\n=== ENV ===")
print(f"model_name={MODEL_NAME}")
print(f"dataset_name={DATASET_NAME}")
print(f"device={device}")
print(f"bad_precision={rt['bad_precision']}")
print(f"good_precision={rt['good_precision']}")
if rt["gpu_name"] is not None:
print(f"gpu={rt['gpu_name']}")
print(f"vram_gb={rt['vram_gb']}")
print(
"\nReminder: in DeBERTa sequence classification, num_labels=1 with the built-in "
"loss path is regression by default. This demo avoids that trap by computing "
"the binary loss explicitly."
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
train_ds, val_ds = load_small_data(tokenizer)
collate = make_collator(tokenizer)
# Smaller physical batch on CPU.
bs = BATCH_SIZE if device.type == "cuda" else 2
train_loader = DataLoader(
train_ds,
batch_size=bs,
shuffle=True,
num_workers=0,
pin_memory=(device.type == "cuda"),
collate_fn=collate,
)
val_loader = DataLoader(
val_ds,
batch_size=max(2, bs),
shuffle=False,
num_workers=0,
pin_memory=(device.type == "cuda"),
collate_fn=collate,
)
bad_summary = run_once(
name="BAD_REPRO",
model_name=MODEL_NAME,
device=device,
precision=rt["bad_precision"],
train_loader=train_loader,
val_loader=val_loader,
use_good_loss=False,
steps=BAD_STEPS,
lr=BAD_LR,
weight_decay=0.0,
adam_eps=1e-8,
clip_norm=None,
)
good_summary = run_once(
name="GOOD_FIX",
model_name=MODEL_NAME,
device=device,
precision=rt["good_precision"],
train_loader=train_loader,
val_loader=val_loader,
use_good_loss=True,
steps=GOOD_STEPS,
lr=GOOD_LR,
weight_decay=GOOD_WEIGHT_DECAY,
adam_eps=GOOD_ADAM_EPS,
clip_norm=GOOD_CLIP_NORM,
)
print("\n=== SUMMARY ===")
print(bad_summary)
print(good_summary)
print("\n=== HOW TO READ THIS ===")
print("- BAD_REPRO is supposed to be fragile.")
print("- GOOD_FIX is supposed to stay finite.")
print("- If BAD_REPRO stays finite on your machine, raise BAD_LOGIT_SCALE or BAD_LR.")
print("- If GOOD_FIX still goes non-finite on CUDA, force fp32 and then try bf16.")
print("- For a closer match to your real case, switch MODEL_NAME to deberta-v3-small.")
if __name__ == "__main__":
main()
"""
=== ENV ===
model_name=microsoft/deberta-v3-xsmall
dataset_name=SetFit/sst2
device=cpu
bad_precision=fp32
good_precision=fp32
...
=== SUMMARY ===
{'name': 'BAD_REPRO', 'steps_done': 1, 'stopped_early': True, 'fail_reason': 'non-finite loss/grad/param detected', 'best_val_acc': 0.0, 'elapsed_sec': 20.62}
{'name': 'GOOD_FIX', 'steps_done': 40, 'stopped_early': False, 'fail_reason': None, 'best_val_acc': 0.5390625, 'elapsed_sec': 820.93}
=== HOW TO READ THIS ===
- BAD_REPRO is supposed to be fragile.
- GOOD_FIX is supposed to stay finite.
- If BAD_REPRO stays finite on your machine, raise BAD_LOGIT_SCALE or BAD_LR.
- If GOOD_FIX still goes non-finite on CUDA, force fp32 and then try bf16.
- For a closer match to your real case, switch MODEL_NAME to deberta-v3-small.
"""
Why these choices:
- `deberta-v3-xsmall` is a lighter default. The model card says it has a 22M backbone, while the same family’s small variant is larger, so `xsmall` is the safer low-VRAM demo baseline. (Hugging Face)
- `SetFit/sst2` is a public binary dataset on the Hub with ready-made splits, so it is simple for a demo and avoids a custom dataset script in your code. (Hugging Face)
- The “bad” path is intentionally wrong because the point is to reproduce the failure shape. The “good” path is the part that matches the stable recommendations from the PyTorch and Transformers docs. (PyTorch Docs)
Thanks for the detailed answer.
