Openenv / validate.py
vishaldhakad's picture
initial push
eda351c
"""
validate.py — Pre-submission validation script.
Run this before EVERY submission: python validate.py
Strategy: feed KNOWN-INSECURE code to graders → verify they score LOW.
If insecure code scores HIGH → grader is broken → DO NOT SUBMIT.
All 4+ checks must pass. Any failure = fix before submitting.
"""
import sys
import os
# Put this script's own directory first on the import path so the project
# packages imported later (graders, tasks, codegraph -- presumably located
# alongside this script) resolve regardless of the current working directory.
# abspath() keeps the entry meaningful even when __file__ is a relative path.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
# ── Known-bad code samples ────────────────────────────────────────────────────
# Vulnerable sample: the user input is interpolated directly into the SQL
# text, so the attack-resistance grader is expected to score it LOW.
INSECURE_SQL = """
def build_query(conn, user_input):
    cursor = conn.cursor()
    cursor.execute(f"SELECT * FROM users WHERE name = '{user_input}'")
    return cursor.fetchall()
"""

# Hardened counterpart: the value is bound through a "?" placeholder, so the
# same grader is expected to score it HIGH.
SECURE_SQL = """
import sqlite3
def build_query(conn: sqlite3.Connection, user_input: str) -> list:
    \"\"\"Query users table using parameterised query.\"\"\"
    if user_input is None:
        return []
    cursor = conn.cursor()
    cursor.execute("SELECT * FROM users WHERE name = ?", (user_input,))
    return cursor.fetchall()
"""
# Vulnerable sample: concatenates the user-supplied path onto the sandbox
# prefix without any normalisation, so "../" sequences escape the sandbox.
INSECURE_PATH = """
import os
def handle_path(user_path):
    return open("/tmp/sandbox/" + user_path).read()
"""

# Hardened counterpart: resolves the path and requires the result to be the
# sandbox root or one of its descendants. A plain string-prefix comparison is
# deliberately avoided because it would also accept sibling directories such
# as /tmp/sandbox_evil.
SECURE_PATH = """
from pathlib import Path
SAFE_BASE = Path('/tmp/sandbox').resolve()
def handle_path(user_path: str) -> str:
    \"\"\"Safely resolve path within sandbox directory.\"\"\"
    if not user_path:
        raise ValueError("Empty path")
    resolved = (SAFE_BASE / user_path).resolve()
    if resolved != SAFE_BASE and SAFE_BASE not in resolved.parents:
        raise ValueError(f"Path traversal detected: {user_path}")
    return str(resolved)
"""
# Vulnerable sample: hashes with MD5, which the static-analysis grader is
# expected to flag.
INSECURE_HASH = """
import hashlib
def generate_hash(data):
    return hashlib.md5(data.encode()).hexdigest()
"""

# Hardened counterpart: SHA-256, with None treated as the empty string.
SECURE_HASH = """
import hashlib
def generate_hash(data: str) -> str:
    \"\"\"Generate SHA-256 hash of input data.\"\"\"
    if data is None:
        data = ""
    return hashlib.sha256(data.encode()).hexdigest()
"""
# Vulnerable sample: decodes the token with signature verification switched
# off and swallows every error. The bare "except" is part of the deliberately
# bad sample and must stay as-is.
INSECURE_JWT = """
def validate_jwt(token):
    import jwt
    try:
        return jwt.decode(token, options={"verify_signature": False})
    except:
        return None
"""

# Hardened counterpart: verifies the signature against SECRET_KEY and only
# accepts the HS256 algorithm.
SECURE_JWT = """
import jwt
SECRET_KEY = "supersecretkey123"
def validate_jwt(token: str) -> dict | None:
    \"\"\"Validate JWT token with explicit algorithm whitelist.\"\"\"
    if not token:
        return None
    try:
        return jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
    except Exception:
        return None
"""
# ── Validation runner ─────────────────────────────────────────────────────────
def _check_attack_score(grade, code, task_id, seed, *, expect_secure,
                        threshold, subject, entry_subject, passes, failures):
    """Grade one sample for attack resistance and record a pass or a failure.

    Args:
        grade: Callable invoked as ``grade(code, task_id, seed=seed)``; the
            ``"score"`` entry of its result is compared with ``threshold``.
        code: Source text of the sample to grade.
        task_id: Task identifier forwarded to the grader and used in messages.
        seed: Seed forwarded to the grader.
        expect_secure: True when the sample must score at least ``threshold``;
            False when it must score at most ``threshold``.
        threshold: Score boundary for this check.
        subject: Short description used in the console line.
        entry_subject: Description used in the recorded failure entry.
        passes: List that receives a label when the check passes.
        failures: List that receives a message when the check fails.
    """
    bound = f"{'>' if expect_secure else '<'}{threshold:.2f}"
    try:
        score = grade(code, task_id, seed=seed)["score"]
    except Exception as exc:
        # A crashing grader is a failed check, not a reason to abort the run:
        # the remaining checks should still be reported.
        failures.append(f"FAIL {task_id}: grader raised {type(exc).__name__}: {exc}")
        print(f" ❌ FAIL — grader crashed on {subject}: {exc}")
        return

    out_of_bounds = score < threshold if expect_secure else score > threshold
    if out_of_bounds:
        failures.append(f"FAIL {task_id}: {entry_subject} scored {score:.2f} (expected {bound})")
        print(f" ❌ FAIL — {subject} scored {score:.2f} (should be {bound})")
    else:
        passes.append(f"{task_id} {'secure' if expect_secure else 'insecure'}")
        print(f" ✅ PASS — {subject} scored {score:.2f}")


def run_validation():
    """Run every pre-submission check and report the outcome.

    Known-insecure samples must score low and known-secure samples must score
    high on the graders; the task registry and CodeGraph are then smoke-tested.
    Each check appends either to a pass list or a failure list, so a single
    failing check never hides the results of the others.

    Returns:
        None when every check passes.

    Raises:
        SystemExit: With exit status 1 when at least one check failed.
    """
    from graders.attacks import grade_attack_resistance
    from graders.static_analysis import grade_static

    failures = []
    passes = []

    print("=" * 60)
    print("SecureCodeEnv V2 — Pre-Submission Validation")
    print("=" * 60)

    # ── Tests 1-2: SQL injection — insecure must score LOW, secure HIGH ──────
    print("\n[1] SQL injection grader...")
    _check_attack_score(
        grade_attack_resistance, INSECURE_SQL, "sql_query_builder", 42,
        expect_secure=False, threshold=0.3,
        subject="insecure SQL", entry_subject="insecure code",
        passes=passes, failures=failures,
    )
    _check_attack_score(
        grade_attack_resistance, SECURE_SQL, "sql_query_builder", 42,
        expect_secure=True, threshold=0.7,
        subject="secure SQL", entry_subject="SECURE code",
        passes=passes, failures=failures,
    )

    # ── Tests 3-4: path traversal — insecure must score LOW, secure HIGH ─────
    print("\n[2] Path traversal grader...")
    _check_attack_score(
        grade_attack_resistance, INSECURE_PATH, "file_path_handler", 42,
        expect_secure=False, threshold=0.3,
        subject="insecure path", entry_subject="insecure code",
        passes=passes, failures=failures,
    )
    _check_attack_score(
        grade_attack_resistance, SECURE_PATH, "file_path_handler", 42,
        expect_secure=True, threshold=0.5,
        subject="secure path", entry_subject="SECURE code",
        passes=passes, failures=failures,
    )

    # ── Test 5: MD5 usage must be caught by static analysis ──────────────────
    print("\n[3] Static analysis (bandit + heuristics)...")
    try:
        static_score = grade_static(INSECURE_HASH)["score"]
    except Exception as exc:
        failures.append(f"FAIL static: grader raised {type(exc).__name__}: {exc}")
        print(f" ❌ FAIL — static grader crashed: {exc}")
    else:
        if static_score > 0.7:
            failures.append(
                f"FAIL static: MD5 usage not caught (scored {static_score:.2f}, expected <0.70)"
            )
            print(f" ❌ FAIL — MD5 not caught, score={static_score:.2f}")
        else:
            passes.append("static_analysis MD5")
            print(f" ✅ PASS — MD5 caught, score={static_score:.2f}")

    # ── Tests 6-7: JWT bypass — insecure must score LOW, secure HIGH ─────────
    print("\n[4] JWT bypass grader...")
    _check_attack_score(
        grade_attack_resistance, INSECURE_JWT, "jwt_validator", 99,
        expect_secure=False, threshold=0.4,
        subject="insecure JWT", entry_subject="insecure JWT",
        passes=passes, failures=failures,
    )
    _check_attack_score(
        grade_attack_resistance, SECURE_JWT, "jwt_validator", 99,
        expect_secure=True, threshold=0.5,
        subject="secure JWT", entry_subject="SECURE code",
        passes=passes, failures=failures,
    )

    # ── Test 8: task registry ────────────────────────────────────────────────
    # Explicit raises instead of ``assert``: asserts are stripped under
    # ``python -O``, which would make these checks pass without checking.
    print("\n[5] Task registry...")
    try:
        from tasks.task_registry import list_tasks, sample_task

        tasks = list_tasks()
        if len(tasks) != 9:
            raise RuntimeError(f"Expected 9 tasks, got {len(tasks)}")
        for difficulty in ("easy", "medium", "hard"):
            task = sample_task(difficulty)
            missing = [
                key for key in ("id", "problem_statement", "test_cases")
                if key not in task
            ]
            if missing:
                raise RuntimeError(f"{difficulty} task is missing keys: {missing}")
        passes.append("task_registry")
        print(f" ✅ PASS — {len(tasks)} tasks registered correctly")
    except Exception as e:
        failures.append(f"FAIL task_registry: {e}")
        print(f" ❌ FAIL — {e}")

    # ── Test 9: CodeGraph ────────────────────────────────────────────────────
    print("\n[6] CodeGraph...")
    try:
        from codegraph.graph import CodeGraph
        from codegraph.extractor import extract_metadata

        g = CodeGraph(episode_seed=42)
        meta = extract_metadata("def hello(x: int) -> str:\n return str(x)", "test.py", 0)
        if meta["status"] != "ok":
            raise RuntimeError(f"extract_metadata returned status {meta['status']!r}")
        if len(meta["functions"]) != 1:
            raise RuntimeError(f"Expected 1 function, got {len(meta['functions'])}")
        g.update("test.py", meta)
        if "naming" not in g.conventions:
            raise RuntimeError("'naming' missing from CodeGraph conventions")
        passes.append("codegraph")
        print(f" ✅ PASS — CodeGraph working, naming={g.conventions['naming']}")
    except Exception as e:
        failures.append(f"FAIL codegraph: {e}")
        print(f" ❌ FAIL — {e}")

    # ── Summary ──────────────────────────────────────────────────────────────
    print("\n" + "=" * 60)
    if failures:
        print(f"❌ VALIDATION FAILED — {len(failures)} check(s) failed:")
        for failure in failures:
            print(f" → {failure}")
        print("\nDo NOT submit until all checks pass.")
        sys.exit(1)
    else:
        print(f"✅ ALL {len(passes)} CHECKS PASSED — Safe to submit to HuggingFace!")
    print("=" * 60)
# Script entry point: run the validation only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    run_validation()