
MicroDehazeNet: Real-Time Video Dehazing for Jetson

Overview

Lightweight single-image/video dehazing model designed for real-time inference (30 FPS) at 512×512 on NVIDIA Jetson Orin.

| Spec | Value |
| --- | --- |
| Parameters | 996,023 (~1M) |
| GMACs | 6.55 @ 512×512 |
| Architecture | 5-stage gated U-Net with SK Fusion |
| Target | ≥30 FPS on Jetson Orin (TensorRT FP16) |
| Dataset | Haze4K (3000 train / 1000 test pairs) |

Architecture Design

Based on two landmark papers:

  • gUNet-T (arxiv:2209.11448): gated depthwise convolutions, SK Fusion skip connections, BatchNorm
  • NAFNet (arxiv:2204.04676): SimpleGate (element-wise product, no Sigmoid) for INT8-friendly inference

Key Design Choices for Jetson

  1. SimpleGate (X ⊙ Y) instead of Sigmoid gating: pure multiply, no lookup tables, better INT8 quantization
  2. DWConv k=5: larger receptive field than 3×3 with minimal compute overhead
  3. BatchNorm (not LayerNorm): fuses into Conv2d at inference, zero runtime cost
  4. Pixel Shuffle upsampling: no deconv checkerboard artifacts
  5. Global residual: output = input + model(input), stabilizes training

Encoder:     3 → 24 → 48 → 96 → 192   (stride-2 conv downsample)
Bottleneck:  192 (4 gConv blocks)
Decoder:     192 → 96 → 48 → 24 → 3    (pixel shuffle upsample + SK Fusion)
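
Design choice 3 can be checked numerically. The following is a minimal sketch, assuming the BatchNorm → 1×1 conv ordering shown in the gConvBlock diagram further down. `fuse_bn_into_pwconv` is a hypothetical helper written for illustration; it is not taken from model.py, and the code that produces pytorch_model_fused.pth may do this differently.

```python
import torch
import torch.nn as nn


def fuse_bn_into_pwconv(bn: nn.BatchNorm2d, conv: nn.Conv2d) -> nn.Conv2d:
    """Fold an eval-mode BatchNorm that precedes a 1x1 conv into that conv."""
    assert conv.kernel_size == (1, 1) and conv.groups == 1
    # In eval mode BN is a per-channel affine map: y = scale * x + shift.
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    shift = bn.bias - bn.running_mean * scale
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, 1, bias=True)
    with torch.no_grad():
        w = conv.weight[:, :, 0, 0]  # (out_channels, in_channels)
        # Scale each input-channel column of the weight, push the shift into the bias.
        fused.weight.copy_((w * scale[None, :])[:, :, None, None])
        bias = conv.bias if conv.bias is not None else torch.zeros(conv.out_channels)
        fused.bias.copy_(bias + w @ shift)
    return fused


torch.manual_seed(0)
bn, conv = nn.BatchNorm2d(24), nn.Conv2d(24, 48, 1)
with torch.no_grad():  # give BN non-trivial statistics so the check means something
    bn.running_mean.normal_()
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.normal_()
    bn.bias.normal_()
bn.eval()

x = torch.rand(1, 24, 32, 32)
with torch.no_grad():
    reference = conv(bn(x))
    fused_out = fuse_bn_into_pwconv(bn, conv)(x)
max_abs_diff = (reference - fused_out).abs().max().item()
print(f"max |reference - fused| = {max_abs_diff:.2e}")
```

The fold is exact here because a 1×1 convolution has no padding; a BatchNorm placed before a padded k×k convolution would not fold this cleanly at the image borders.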

Quick Start

Training

# Install dependencies
pip install torch torchvision numpy Pillow huggingface_hub trackio onnx onnxscript

# Train (downloads Haze4K automatically from HF Hub)
python train.py

# Or via HF Jobs (GPU):
# Hardware: a10g-large, Timeout: 6h
# Env: HUB_MODEL_ID=your-username/micro-dehaze-net

Inference (PyTorch)

import torch
from model import MicroDehazeNet

model = MicroDehazeNet(base_channels=24, num_blocks=2)
state_dict = torch.load('pytorch_model.pth', map_location='cpu')
model.load_state_dict(state_dict)
model.eval()

# Input: [B, 3, H, W] in [0, 1] range
hazy = torch.rand(1, 3, 512, 512)  # uniform in [0, 1); replace with a real hazy frame
with torch.no_grad():
    dehazed = model(hazy).clamp(0, 1)
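
For real frames, the [B, 3, H, W] input in [0, 1] has to be built from 8-bit image data and converted back afterwards. A minimal sketch of that conversion follows; `frame_to_input` and `output_to_frame` are hypothetical helper names, and the RGB channel order is an assumption (frames decoded with OpenCV, for example, arrive as BGR and would need reordering first).

```python
import torch


def frame_to_input(frame: torch.Tensor) -> torch.Tensor:
    """HWC uint8 frame -> [1, 3, H, W] float tensor in [0, 1]."""
    return frame.permute(2, 0, 1).unsqueeze(0).float() / 255.0


def output_to_frame(output: torch.Tensor) -> torch.Tensor:
    """[1, 3, H, W] float tensor -> HWC uint8 frame."""
    scaled = (output.clamp(0, 1) * 255.0).round().to(torch.uint8)
    return scaled.squeeze(0).permute(1, 2, 0)


# Random stand-in for a decoded 512x512 RGB frame.
frame = torch.randint(0, 256, (512, 512, 3), dtype=torch.uint8)
hazy = frame_to_input(frame)

# Identity round trip; a real pipeline would pass model(hazy) here instead.
restored = output_to_frame(hazy)
print(hazy.shape, restored.shape, restored.dtype)
```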

Jetson Deployment (TensorRT)

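# Step 1: Convert ONNX → TensorRT engine (on Jetson)
# Note: --workspace is deprecated in newer TensorRT releases; there, use --memPoolSize=workspace:1024 instead
/usr/src/tensorrt/bin/trtexec \
    --onnx=micro_dehaze_net.onnx \
    --saveEngine=micro_dehaze_net_fp16.engine \
    --fp16 --workspace=1024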

# Step 2: Run real-time video dehazing
python jetson_inference.py --engine micro_dehaze_net_fp16.engine --source 0

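# Or INT8 for maximum speed:
# Note: without calibration data (e.g. a cache passed via --calib=<file>), an INT8 engine
# built this way is suitable for timing only; output quality should be validated separately.
/usr/src/tensorrt/bin/trtexec \
    --onnx=micro_dehaze_net.onnx \
    --saveEngine=micro_dehaze_net_int8.engine \
    --int8 --workspace=1024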

Benchmark

python jetson_inference.py --model pytorch_model.pth --benchmark
python jetson_inference.py --engine micro_dehaze_net_fp16.engine --benchmark
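
When comparing the two backends it helps to know what a latency number typically includes. The sketch below shows one common way to time inference (warmup iterations, then an averaged timed loop with explicit synchronization). It is an illustration only: `mean_latency_ms` is a hypothetical helper, and nothing here is taken from what `jetson_inference.py --benchmark` actually does.

```python
import time
from typing import Callable, Optional


def mean_latency_ms(run_once: Callable[[], object],
                    warmup: int = 10,
                    iters: int = 50,
                    sync: Optional[Callable[[], None]] = None) -> float:
    """Average wall-clock time of run_once() in milliseconds."""
    for _ in range(warmup):  # let caches, clocks and lazy initialization settle
        run_once()
    if sync is not None:
        sync()               # drain queued asynchronous work before timing starts
    start = time.perf_counter()
    for _ in range(iters):
        run_once()
    if sync is not None:
        sync()               # wait for the last launch to finish
    return (time.perf_counter() - start) * 1000.0 / iters


# CPU-only stand-in workload so the sketch runs anywhere.
latency = mean_latency_ms(lambda: sum(i * i for i in range(10_000)))
print(f"{latency:.3f} ms per call, {1000.0 / latency:.1f} calls/s")
```

With PyTorch on a GPU, `run_once` would wrap `model(hazy)` under `torch.no_grad()` and `sync` would be `torch.cuda.synchronize`, since CUDA kernels launch asynchronously and an unsynchronized timer under-reports latency.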

Training Recipe

| Hyperparameter | Value | Source |
| --- | --- | --- |
| Loss | Charbonnier (ε=1e-3) | gUNet, DEA-Net |
| Optimizer | AdamW (lr=2e-4, β=(0.9, 0.999)) | gUNet |
| LR schedule | Cosine annealing (2e-4 → 1e-6) with 10-epoch warmup | gUNet |
| Batch size | 16 | - |
| Patch size | 256×256 | Common in dehazing papers |
| Augmentation | Random crop + H/V flip + 90° rotation | Standard |
| Mixed precision | FP16 (autocast + GradScaler) | - |
| BN freezing | Last 20% of training | gUNet paper |
| Gradient clip | 1.0 | - |
| Epochs | 500 | - |
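
The loss, learning-rate schedule and BN-freezing rows can be written out in a few lines of plain Python. This is a sketch under stated assumptions, not a copy of train.py: the warmup is assumed to be linear and to start at base_lr / 10, the schedule is assumed to step once per epoch, and the Charbonnier form with ε squared inside the root is one common variant. All three function names are hypothetical.

```python
import math


def charbonnier(pred, target, eps=1e-3):
    """Mean of sqrt((p - t)^2 + eps^2) over two equal-length sequences."""
    total = sum(math.sqrt((p - t) ** 2 + eps ** 2) for p, t in zip(pred, target))
    return total / len(pred)


def lr_at_epoch(epoch, total=500, warmup=10, base_lr=2e-4, min_lr=1e-6):
    """Linear warmup to base_lr, then cosine decay to min_lr at the last epoch."""
    if epoch < warmup:
        return base_lr * (epoch + 1) / warmup
    progress = (epoch - warmup) / max(1, total - warmup - 1)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


def bn_frozen(epoch, total=500, frozen_fraction=0.2):
    """True once training enters the final frozen_fraction of epochs."""
    return epoch >= total - round(total * frozen_fraction)


schedule = [lr_at_epoch(e) for e in range(500)]
first_frozen_epoch = next(e for e in range(500) if bn_frozen(e))

print(f"lr at epoch 0 / 10 / 499: {schedule[0]:.1e} / {schedule[10]:.1e} / {schedule[-1]:.1e}")
print(f"BatchNorm statistics frozen from epoch {first_frozen_epoch}")
print(f"Charbonnier loss for a perfect prediction: {charbonnier([0.5], [0.5]):.1e}")
```

Note that the Charbonnier loss never reaches zero: a perfect prediction still costs ε, which is what keeps the gradient well behaved near zero error.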

Latency Estimates

| Platform | Precision | Est. latency @ 512×512 | Est. FPS |
| --- | --- | --- | --- |
| Jetson Orin NX (16GB) | FP16 TensorRT | ~15-20 ms | 50-65 |
| Jetson Orin NX (16GB) | INT8 TensorRT | ~8-12 ms | 80-125 |
| Jetson Orin Nano (8GB) | FP16 TensorRT | ~25-30 ms | 33-40 |
| A10G (24GB) | FP32 PyTorch | ~3-5 ms | 200+ |
| RTX 3090 | FP32 PyTorch | ~3-4 ms | 250+ |

These are estimates, not measurements. They are extrapolated from the model's 6.55 GMACs and from the latency published in the gUNet paper for a similar architecture (3.39 ms at 256×256 on an RTX 3090). Actual latency depends on the TensorRT version and the Jetson's thermal state.
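
The FPS column follows directly from the latency column (FPS = 1000 / latency in ms), and the 30 FPS target corresponds to a budget of about 33.3 ms per frame. A small sketch of that arithmetic, using the Jetson estimates from the table (the helper names are hypothetical):

```python
def fps_from_latency(latency_ms: float) -> float:
    """Frames per second for a given per-frame latency in milliseconds."""
    return 1000.0 / latency_ms


def meets_target(latency_ms: float, target_fps: float = 30.0) -> bool:
    """True if the latency fits inside the per-frame budget for target_fps."""
    return fps_from_latency(latency_ms) >= target_fps


# (best case, worst case) latency estimates in ms, copied from the table above.
estimates_ms = {
    "Orin NX FP16": (15, 20),
    "Orin NX INT8": (8, 12),
    "Orin Nano FP16": (25, 30),
}

budget_ms = 1000.0 / 30.0
print(f"30 FPS budget: {budget_ms:.1f} ms per frame")
for name, (best, worst) in estimates_ms.items():
    print(f"{name}: {fps_from_latency(worst):.1f}-{fps_from_latency(best):.1f} FPS, "
          f"worst case meets 30 FPS: {meets_target(worst)}")
```

The Orin Nano worst case (30 ms) leaves only about 3 ms of headroom, so capture, pre-processing and display time matter there.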

Files

| File | Description |
| --- | --- |
| model.py | Model architecture (MicroDehazeNet) |
| train.py | Full training script with data download |
| jetson_inference.py | Real-time inference for Jetson (TensorRT + PyTorch) |
| pytorch_model.pth | Trained weights (after training) |
| pytorch_model_fused.pth | BN-fused weights for inference |
| micro_dehaze_net.onnx | ONNX export for TensorRT conversion |
| config.json | Model config and training metrics |

Architecture Details

gConvBlock

Input → BatchNorm → PWConv(expand 2×) → DWConv(k=5) → SimpleGate → PWConv(project) → + Input
  • SimpleGate: splits channels in half, multiplies elementwise (no Sigmoid)
  • DWConv k=5: per-channel convolution for spatial mixing
  • PWConv: 1×1 convolutions for channel mixing
  • BatchNorm: fuses into PWConv at inference time
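
The block diagram above translates fairly directly into PyTorch. The sketch below is a reading of that diagram, not the code in model.py: the class names are hypothetical, "expand 2×" is assumed to mean 2C channels (so that SimpleGate returns C channels), and bias and initialization choices are left at PyTorch defaults.

```python
import torch
import torch.nn as nn


class SimpleGate(nn.Module):
    """Split channels in half and multiply the halves elementwise."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        a, b = x.chunk(2, dim=1)
        return a * b


class GConvBlockSketch(nn.Module):
    """BatchNorm -> PWConv(expand) -> DWConv(k=5) -> SimpleGate -> PWConv(project) -> + input."""

    def __init__(self, channels: int, kernel_size: int = 5):
        super().__init__()
        expanded = 2 * channels
        self.norm = nn.BatchNorm2d(channels)
        self.expand = nn.Conv2d(channels, expanded, kernel_size=1)
        # groups == channels makes this a depthwise (per-channel) convolution.
        self.dwconv = nn.Conv2d(expanded, expanded, kernel_size,
                                padding=kernel_size // 2, groups=expanded)
        self.gate = SimpleGate()                      # 2C -> C channels
        self.project = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.norm(x)
        y = self.expand(y)
        y = self.dwconv(y)
        y = self.gate(y)
        y = self.project(y)
        return x + y                                  # local residual


block = GConvBlockSketch(24).eval()
features = torch.rand(2, 24, 32, 32)
with torch.no_grad():
    out = block(features)
print(out.shape)
```

Because the gate is a plain multiply, the block contains no transcendental activation at all, which is the property the INT8 argument in the design choices relies on.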

SK Fusion (Skip Connection)

skip + decoder → GAP → FC → ReLU → FC → Split → Softmax → weighted sum

Adaptively weights encoder skip features vs decoder features.
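
A minimal sketch of the fusion path described above, assuming both inputs share the same shape, that the FC layers are implemented as 1×1 convolutions on the pooled vector, and that the hidden width uses a reduction ratio of 8. `SKFusionSketch` and the `reduction` parameter are hypothetical names; the implementation in model.py may differ.

```python
import torch
import torch.nn as nn


class SKFusionSketch(nn.Module):
    """Per-channel softmax weighting of skip features against decoder features."""

    def __init__(self, channels: int, reduction: int = 8):
        super().__init__()
        hidden = max(channels // reduction, 4)
        self.pool = nn.AdaptiveAvgPool2d(1)           # GAP
        self.mlp = nn.Sequential(                     # FC -> ReLU -> FC
            nn.Conv2d(channels, hidden, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, 2 * channels, kernel_size=1),
        )

    def forward(self, skip: torch.Tensor, decoder: torch.Tensor) -> torch.Tensor:
        b, c, _, _ = skip.shape
        logits = self.mlp(self.pool(skip + decoder))  # (B, 2C, 1, 1)
        # Split into one logit per branch, then softmax across the two branches.
        weights = logits.view(b, 2, c, 1, 1).softmax(dim=1)
        return weights[:, 0] * skip + weights[:, 1] * decoder


fusion = SKFusionSketch(48).eval()
skip_feat = torch.rand(1, 48, 64, 64)
dec_feat = torch.rand(1, 48, 64, 64)
with torch.no_grad():
    fused = fusion(skip_feat, dec_feat)
    same = fusion(skip_feat, skip_feat)
print(fused.shape)
```

Since the two weights for each channel sum to 1, the output is a convex combination of the inputs: feeding the same tensor into both branches returns that tensor unchanged.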

References

  1. gUNet: "Rethinking Performance Gains in Image Dehazing Networks", Song et al., 2022 (arxiv:2209.11448)
  2. NAFNet: "Simple Baselines for Image Restoration", Chen et al., 2022 (arxiv:2204.04676)
  3. DEA-Net: "Detail-Enhanced Convolution and Content-Guided Attention", Chen et al., 2023 (arxiv:2301.04805)
  4. Haze4K: "PMNet", Ye et al., 2022

License

MIT
