DEA-Net: Single image dehazing based on detail-enhanced convolution and content-guided attention
Paper: arXiv 2301.04805
Lightweight single-image/video dehazing model designed for real-time inference (30 FPS) at 512×512 on NVIDIA Jetson Orin.
| Spec | Value |
|---|---|
| Parameters | 996,023 (~1M) |
| GMACs | 6.55 @512×512 |
| Architecture | 5-stage gated U-Net with SK Fusion |
| Target | ≥30 FPS on Jetson Orin (TensorRT FP16) |
| Dataset | Haze4K (3000 train / 1000 test pairs) |

The design draws on two papers, gUNet and DEA-Net. Channel layout:
- Encoder: 3 → 24 → 48 → 96 → 192 (stride-2 conv downsampling)
- Bottleneck: 192 channels (4 gConv blocks)
- Decoder: 192 → 96 → 48 → 24 → 3 (pixel-shuffle upsampling + SK Fusion)
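
To see how resolution and channels move through this layout, here is a small pure-Python shape trace. It rests on assumptions the card does not state: the 3 → 24 step is a full-resolution stem, each later encoder step halves H and W with a stride-2 conv, and each decoder step doubles them with a factor-2 pixel shuffle. Under those assumptions the input side must be a multiple of 8. The helper `trace_shapes` is hypothetical, not part of `model.py`.

```python
# Sketch: trace (channels, H, W) through the U-Net layout described above.
# Assumptions (not confirmed by the card): 3 -> 24 stem at full resolution,
# stride-2 downsampling for each later encoder step, PixelShuffle(2) upsampling.

ENC_CHANNELS = [24, 48, 96, 192]
DEC_CHANNELS = [96, 48, 24]


def trace_shapes(h: int, w: int, in_ch: int = 3, out_ch: int = 3):
    """Return a list of (stage_name, channels, height, width) tuples."""
    factor = 2 ** (len(ENC_CHANNELS) - 1)  # total downsampling factor (8)
    if h % factor or w % factor:
        raise ValueError(f"H and W must be multiples of {factor}, got {h}x{w}")

    shapes = [("input", in_ch, h, w)]
    shapes.append(("enc0", ENC_CHANNELS[0], h, w))  # stem keeps resolution
    for i, ch in enumerate(ENC_CHANNELS[1:], start=1):
        h, w = h // 2, w // 2  # stride-2 conv halves H and W
        shapes.append((f"enc{i}", ch, h, w))
    # gConv blocks in the bottleneck keep channels and resolution unchanged.
    shapes.append(("bottleneck", ENC_CHANNELS[-1], h, w))
    for i, ch in enumerate(DEC_CHANNELS):
        # PixelShuffle(2) maps (4*ch, h, w) -> (ch, 2h, 2w), so the conv in
        # front of it has to produce 4 * ch channels.
        h, w = h * 2, w * 2
        shapes.append((f"dec{i}", ch, h, w))
    shapes.append(("output", out_ch, h, w))  # final 24 -> 3 projection
    return shapes


if __name__ == "__main__":
    for name, c, hh, ww in trace_shapes(512, 512):
        print(f"{name:>10}: {c:>3} x {hh} x {ww}")
```

At 512×512 this puts the 192-channel bottleneck at 64×64, which is where most of the network's depth operates cheaply.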
```bash
# Install dependencies
pip install torch torchvision numpy Pillow huggingface_hub trackio onnx onnxscript

# Train (downloads Haze4K automatically from HF Hub)
python train.py

# Or via HF Jobs (GPU):
# Hardware: a10g-large, Timeout: 6h
# Env: HUB_MODEL_ID=your-username/micro-dehaze-net
```
```python
import torch
from model import MicroDehazeNet

model = MicroDehazeNet(base_channels=24, num_blocks=2)
state_dict = torch.load('pytorch_model.pth', map_location='cpu')
model.load_state_dict(state_dict)
model.eval()

# Input: [B, 3, H, W] in [0, 1] range
hazy = torch.rand(1, 3, 512, 512)  # uniform in [0, 1), matching the expected range
with torch.no_grad():
    dehazed = model(hazy).clamp(0, 1)
```
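
The example above feeds a random tensor. For real frames, a minimal NumPy sketch of pre- and post-processing follows. It assumes RGB `uint8` input and reflect-pads H and W to a multiple of 8, which is an assumption based on three stride-2 downsamples, not something this card states. `to_model_input` and `from_model_output` are hypothetical helpers, not part of `model.py`.

```python
import numpy as np


def to_model_input(image_hwc_uint8: np.ndarray, multiple: int = 8):
    """RGB uint8 (H, W, 3) -> float32 (1, 3, H', W') in [0, 1], reflect-padded
    so H' and W' are multiples of `multiple`. Also returns the original (H, W)."""
    if image_hwc_uint8.ndim != 3 or image_hwc_uint8.shape[2] != 3:
        raise ValueError("expected an (H, W, 3) RGB image")
    h, w = image_hwc_uint8.shape[:2]
    x = image_hwc_uint8.astype(np.float32) / 255.0   # scale to [0, 1]
    x = np.transpose(x, (2, 0, 1))[None]             # HWC -> NCHW with N = 1
    pad_h = (-h) % multiple                           # rows needed to reach a multiple
    pad_w = (-w) % multiple
    x = np.pad(x, ((0, 0), (0, 0), (0, pad_h), (0, pad_w)), mode="reflect")
    return np.ascontiguousarray(x), (h, w)


def from_model_output(y: np.ndarray, size):
    """float (1, 3, H', W') model output -> uint8 (H, W, 3), cropping the padding."""
    h, w = size
    y = np.clip(y[0, :, :h, :w], 0.0, 1.0)
    return (np.transpose(y, (1, 2, 0)) * 255.0 + 0.5).astype(np.uint8)  # round
```

With PyTorch, wrap the array with `torch.from_numpy(x)` before calling `model`, and pass `model(...).numpy()` back to `from_model_output`. Reflect padding is used instead of zero padding so the network does not see an artificial black border.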
```bash
# Step 1: Convert ONNX → TensorRT engine (on Jetson)
/usr/src/tensorrt/bin/trtexec \
  --onnx=micro_dehaze_net.onnx \
  --saveEngine=micro_dehaze_net_fp16.engine \
  --fp16 --workspace=1024

# Step 2: Run real-time video dehazing
python jetson_inference.py --engine micro_dehaze_net_fp16.engine --source 0

# Or build an INT8 engine for maximum speed. Without a calibration cache
# (--calib), trtexec uses placeholder INT8 scales, so this engine's output
# quality is not meaningful; treat it as a speed measurement only.
/usr/src/tensorrt/bin/trtexec \
  --onnx=micro_dehaze_net.onnx \
  --saveEngine=micro_dehaze_net_int8.engine \
  --int8 --workspace=1024
```
```bash
# Benchmark: PyTorch weights vs. the FP16 TensorRT engine
python jetson_inference.py --model pytorch_model.pth --benchmark
python jetson_inference.py --engine micro_dehaze_net_fp16.engine --benchmark
```
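
How `--benchmark` measures latency is not documented here. To time a model yourself, a generic approach is to discard warm-up runs and then report mean, median and 95th-percentile latency. The `benchmark` helper below is a hypothetical sketch, not part of `jetson_inference.py`.

```python
import statistics
import time
from typing import Callable, Dict


def benchmark(fn: Callable[[], object], warmup: int = 20, iters: int = 200) -> Dict[str, float]:
    """Time a zero-argument callable; return latency stats in ms and mean FPS."""
    if iters < 1:
        raise ValueError("iters must be >= 1")
    for _ in range(warmup):
        fn()  # discard warm-up runs (lazy init, clock ramp-up, caches)
    samples_ms = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        samples_ms.append((time.perf_counter() - start) * 1000.0)
    samples_ms.sort()
    mean_ms = statistics.fmean(samples_ms)
    p95_index = min(len(samples_ms) - 1, int(0.95 * len(samples_ms)))
    return {
        "mean_ms": mean_ms,
        "p50_ms": statistics.median(samples_ms),
        "p95_ms": samples_ms[p95_index],
        "fps": 1000.0 / mean_ms if mean_ms > 0 else float("inf"),
    }
```

For GPU inference, make `fn` end with `torch.cuda.synchronize()` (or synchronize the CUDA stream the TensorRT engine runs on), for example `benchmark(lambda: (model(hazy), torch.cuda.synchronize()))`; otherwise the timer only measures kernel launches. Reporting p95 alongside the mean shows whether occasional slow frames (for example from thermal throttling) would break a 30 FPS stream.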
| Hyperparameter | Value | Source |
|---|---|---|
| Loss | Charbonnier (ε=1e-3) | gUNet, DEA-Net |
| Optimizer | AdamW (lr=2e-4, betas=(0.9, 0.999)) | gUNet |
| LR schedule | Cosine annealing (2e-4 → 1e-6) with 10-epoch warmup | gUNet |
| Batch size | 16 | - |
| Patch size | 256×256 | Common practice in dehazing papers |
| Augmentation | Random crop + horizontal/vertical flip + 90° rotation | Common practice |
| Mixed precision | FP16 (autocast + GradScaler) | - |
| BN freezing | Last 20% of training | gUNet paper |
| Gradient clipping | 1.0 | - |
| Epochs | 500 | - |
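
A minimal sketch of the loss, learning-rate schedule, BN-freeze point and paired augmentation from the table. It assumes a linear warmup stepped once per epoch and hazy/clear pairs stored as NumPy `(H, W, C)` arrays; `train.py` may implement these differently (for example, stepping the schedule per iteration). All function names are hypothetical.

```python
import math

import numpy as np

# Values taken from the hyperparameter table above.
BASE_LR, MIN_LR = 2e-4, 1e-6
WARMUP_EPOCHS, TOTAL_EPOCHS = 10, 500
BN_FREEZE_FRACTION = 0.2
EPS = 1e-3


def charbonnier_loss(pred, target, eps=EPS):
    """Mean of sqrt((pred - target)^2 + eps^2): L1-like, but smooth at zero."""
    diff = np.asarray(pred, dtype=np.float64) - np.asarray(target, dtype=np.float64)
    return float(np.mean(np.sqrt(diff * diff + eps * eps)))


def lr_at_epoch(epoch):
    """Linear warmup to BASE_LR, then cosine annealing down to MIN_LR (assumed per-epoch)."""
    if epoch < WARMUP_EPOCHS:
        return BASE_LR * (epoch + 1) / WARMUP_EPOCHS
    progress = (epoch - WARMUP_EPOCHS) / max(1, TOTAL_EPOCHS - 1 - WARMUP_EPOCHS)
    return MIN_LR + 0.5 * (BASE_LR - MIN_LR) * (1.0 + math.cos(math.pi * progress))


def bn_frozen(epoch):
    """True during the last 20% of training (epoch 400 onward for 500 epochs)."""
    return epoch >= TOTAL_EPOCHS - round(TOTAL_EPOCHS * BN_FREEZE_FRACTION)


def augment_pair(hazy, clear, patch=256, rng=None):
    """Apply the same random crop, H/V flips and 90-degree rotation to an (H, W, C) pair."""
    rng = np.random.default_rng() if rng is None else rng
    h, w = hazy.shape[:2]
    if h < patch or w < patch:
        raise ValueError("image is smaller than the training patch")
    top = int(rng.integers(0, h - patch + 1))
    left = int(rng.integers(0, w - patch + 1))
    hazy = hazy[top:top + patch, left:left + patch]
    clear = clear[top:top + patch, left:left + patch]
    if rng.random() < 0.5:  # horizontal flip
        hazy, clear = hazy[:, ::-1], clear[:, ::-1]
    if rng.random() < 0.5:  # vertical flip
        hazy, clear = hazy[::-1], clear[::-1]
    k = int(rng.integers(0, 4))  # rotate both by k * 90 degrees
    return np.rot90(hazy, k).copy(), np.rot90(clear, k).copy()
```

Applying identical geometric transforms to both images keeps the hazy/clear pair pixel-aligned, which the per-pixel Charbonnier loss depends on.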

| Platform | Precision | Est. latency @512×512 | Est. FPS |
|---|---|---|---|
| Jetson Orin NX (16GB) | FP16 TensorRT | ~15-20 ms | 50-65 |
| Jetson Orin NX (16GB) | INT8 TensorRT | ~8-12 ms | 80-125 |
| Jetson Orin Nano (8GB) | FP16 TensorRT | ~25-30 ms | 33-40 |
| A10G (24GB) | FP32 PyTorch | ~3-5 ms | 200+ |
| RTX 3090 | FP32 PyTorch | ~3-4 ms | 250+ |

These figures are estimates, not measurements. They are extrapolated from the model's compute cost (6.55 GMACs) and from the gUNet paper's published benchmark for a similar architecture (3.39 ms at 256×256 on an RTX 3090). Actual latency depends on the TensorRT version and the Jetson's thermal state.
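
The FPS column follows from FPS = 1000 / latency (ms), and the 30 FPS target leaves a budget of about 33.3 ms per frame. Since convolutions dominate the cost, MACs scale roughly with pixel count, so 6.55 GMACs at 512×512 corresponds to about 1.64 GMACs at the 256×256 resolution of the gUNet benchmark. A small sketch of that arithmetic (helper names are hypothetical):

```python
TARGET_FPS = 30
GMACS_512 = 6.55  # from the spec table


def fps_from_latency(latency_ms: float) -> float:
    """Frames per second for a given per-frame latency in milliseconds."""
    return 1000.0 / latency_ms


def frame_budget_ms(target_fps: float = TARGET_FPS) -> float:
    """Largest per-frame latency (ms) that still meets the target frame rate."""
    return 1000.0 / target_fps


def gmacs_at(side: int, ref_side: int = 512, ref_gmacs: float = GMACS_512) -> float:
    """Scale compute by pixel count, ignoring small resolution-independent terms."""
    return ref_gmacs * (side / ref_side) ** 2


if __name__ == "__main__":
    estimates = [("Orin NX FP16", 15, 20), ("Orin NX INT8", 8, 12), ("Orin Nano FP16", 25, 30)]
    for name, lo_ms, hi_ms in estimates:
        print(f"{name}: {fps_from_latency(hi_ms):.0f}-{fps_from_latency(lo_ms):.0f} FPS, "
              f"meets {TARGET_FPS} FPS: {hi_ms <= frame_budget_ms()}")
    print(f"Per-frame budget at {TARGET_FPS} FPS: {frame_budget_ms():.1f} ms")
    print(f"Approx. GMACs at 256x256: {gmacs_at(256):.2f}")
```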
| File | Description |
|---|---|
| model.py | Model architecture (MicroDehazeNet) |
| train.py | Full training script with data download |
| jetson_inference.py | Real-time inference for Jetson (TensorRT + PyTorch) |
| pytorch_model.pth | Trained weights (available after training) |
| pytorch_model_fused.pth | BN-fused weights for inference |
| micro_dehaze_net.onnx | ONNX export for TensorRT conversion |
| config.json | Model config and training metrics |

gConv block: Input → BatchNorm → PWConv (expand 2×) → DWConv (k=5) → SimpleGate → PWConv (project) → + Input (residual)
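
The file table lists `pytorch_model_fused.pth` as BN-fused weights but does not say how the fusion is done. Because BatchNorm sits directly in front of a 1×1 PWConv in this block, an eval-mode BN can be folded exactly into that conv (a 1×1 conv has no padding, so there are no border effects). The NumPy sketch below assumes that placement and also shows SimpleGate (split channels in half, multiply the halves), which undoes the 2× expansion. Function names are hypothetical, not taken from `model.py`.

```python
import numpy as np


def simple_gate(x):
    """SimpleGate on a (C, H, W) array: split channels in half and multiply (C -> C/2)."""
    a, b = np.split(x, 2, axis=0)
    return a * b


def fold_bn_into_pwconv(weight, bias, gamma, beta, mean, var, eps=1e-5):
    """Fold an eval-mode BatchNorm that *precedes* a 1x1 conv into that conv.

    weight: (C_out, C_in) 1x1 kernel, bias: (C_out,)
    gamma, beta, mean, var: (C_in,) BN affine parameters and running stats.
    BN(x) = s * x + t with s = gamma / sqrt(var + eps), t = beta - s * mean,
    so W @ BN(x) + b = (W * s) @ x + (W @ t + b).
    """
    s = gamma / np.sqrt(var + eps)
    t = beta - s * mean
    return weight * s[None, :], weight @ t + bias


def bn_then_pwconv(x, weight, bias, gamma, beta, mean, var, eps=1e-5):
    """Unfused reference on a (C_in, H, W) array: BatchNorm, then 1x1 conv."""
    s = gamma / np.sqrt(var + eps)
    xn = s[:, None, None] * (x - mean[:, None, None]) + beta[:, None, None]
    return np.einsum("oc,chw->ohw", weight, xn) + bias[:, None, None]
```

The default `eps=1e-5` matches PyTorch's `BatchNorm2d` default; use the model's actual value if it differs.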

SK Fusion: skip + decoder → GAP → FC → ReLU → FC → Split → Softmax → weighted sum
Adaptively weights the encoder skip features against the decoder features.
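
A minimal NumPy sketch of the SK Fusion flow above for a single (C, H, W) sample. The weight shapes, the hidden width `d`, and the assumption that the first C logits belong to the skip branch are illustrative choices, not taken from `model.py`; the function name `sk_fusion` is hypothetical.

```python
import numpy as np


def sk_fusion(skip, dec, w1, b1, w2, b2):
    """Fuse two (C, H, W) feature maps with per-channel softmax weights.

    w1: (d, C), b1: (d,)    first FC (squeeze), followed by ReLU
    w2: (2C, d), b2: (2C,)  second FC, one logit per branch and channel
    Returns the fused map and the (2, C) attention weights.
    """
    c = skip.shape[0]
    z = (skip + dec).mean(axis=(1, 2))           # GAP over H, W -> (C,)
    h = np.maximum(w1 @ z + b1, 0.0)             # FC -> ReLU -> (d,)
    logits = (w2 @ h + b2).reshape(2, c)         # FC -> split into 2 branches
    logits -= logits.max(axis=0, keepdims=True)  # numerically stable softmax
    attn = np.exp(logits)
    attn /= attn.sum(axis=0, keepdims=True)      # softmax across the 2 branches
    out = attn[0][:, None, None] * skip + attn[1][:, None, None] * dec
    return out, attn
```

Because the two weights for each channel sum to 1, the block can only interpolate between the skip and decoder features rather than rescale them, which keeps the fusion well behaved.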

License: MIT