# Uploaded by: ryanyen22
# Commit message: feat: add reason_first_program/steering.py
# Commit: 4839aff (verified)
"""
Stage 4: Query Language and Steering Interface
Provides a formal query language for navigating the program space by composing
concepts. Users express preferences as concept coordinates, and the system
steers LLM generation accordingly.
The query language supports:
- Single concept selection: steer("recursive", strength=0.8)
- Concept composition: steer(recursive=0.8, space_efficient=0.6)
- Concept negation: steer(mutation=-0.5) (avoid mutation)
- Region queries: select(region_where(recursive > 0.5, fast_execution > 0.7))
- Lattice navigation: refine(current, add_concept="memoization")
"""
from __future__ import annotations
import logging
import re
from dataclasses import dataclass, field
from typing import Any, Optional, Union
import numpy as np
from reason_first_program.program_space import Program, ProgramSpace
from reason_first_program.concepts import Concept, ConceptSet
from reason_first_program.embeddings import (
ConceptEmbeddingSpace,
GCAVEmbedding,
MSRSSteering,
)
logger = logging.getLogger(__name__)
@dataclass
class ConceptQuery:
    """
    A query in the concept space.

    A query is a weighted combination of concepts that defines a target
    region in the program space. The system either:
      1. Selects existing programs nearest to this region, or
      2. Steers generation toward this region.

    Formally: q = Σ_i w_i · v_i where v_i is concept i's activation vector
    """

    # Soft preferences: {concept_name: signed weight}. Positive weights ask for
    # the concept, negative weights ask to avoid it.
    weights: dict[str, float] = field(default_factory=dict)
    # Hard constraints: {concept_name: (operator, threshold)},
    # e.g. {"recursive": (">", 0.5)}. Operators: ">", ">=", "<", "<=", "==".
    constraints: dict[str, tuple[str, float]] = field(default_factory=dict)
    # Free-form extra data; not read by any method of this class.
    metadata: dict[str, Any] = field(default_factory=dict)

    # Half-width of the band within which "==" treats a score as equal to the
    # threshold. Deliberately un-annotated so the dataclass machinery does not
    # turn it into a constructor field.
    EQ_TOLERANCE = 0.05

    def __repr__(self) -> str:
        """Render the query compactly, strongest weights first, then constraints."""
        parts = []
        by_magnitude = sorted(self.weights.items(), key=lambda item: -abs(item[1]))
        for name, weight in by_magnitude:
            # Only strictly positive weights get an explicit "+" sign.
            sign = "+" if weight > 0 else ""
            parts.append(f"{sign}{weight:.1f}·{name}")
        for name, (op, val) in self.constraints.items():
            parts.append(f"{name}{op}{val:.1f}")
        return f"Query({', '.join(parts)})"

    @property
    def concept_vector(self) -> dict[str, float]:
        """The query as a concept-space direction vector (a copy of ``weights``)."""
        return self.weights.copy()

    def matches(self, concept_scores: dict[str, float]) -> bool:
        """
        Check if a program's concept scores satisfy the query constraints.

        Args:
            concept_scores: Mapping of concept name to score. A concept that
                is missing from the mapping is treated as scoring 0.0.

        Returns:
            True when every constraint holds (vacuously True when there are
            no constraints), False otherwise.

        Raises:
            ValueError: If a constraint uses an operator other than
                ">", ">=", "<", "<=" or "==".
        """
        return all(
            self._satisfies(op, concept_scores.get(name, 0.0), threshold)
            for name, (op, threshold) in self.constraints.items()
        )

    def _satisfies(self, op: str, score: float, threshold: float) -> bool:
        """Evaluate a single ``score <op> threshold`` comparison."""
        if op == ">":
            return score > threshold
        if op == ">=":
            return score >= threshold
        if op == "<":
            return score < threshold
        if op == "<=":
            return score <= threshold
        if op == "==":
            # Scores are floats, so equality means "within EQ_TOLERANCE".
            return abs(score - threshold) < self.EQ_TOLERANCE
        # An unknown operator used to be silently accepted, which turned a
        # mistyped hard constraint into a no-op filter.
        raise ValueError(f"Unsupported constraint operator: {op!r}")

    def distance_to(self, concept_scores: dict[str, float]) -> float:
        """
        Compute distance from a program's concept scores to this query.

        Each weighted concept has an ideal score of 1.0 (positive weight) or
        0.0 (otherwise); the squared gap to that ideal is scaled by the
        weight's magnitude and summed. Missing scores count as 0.0.

        Lower = more aligned with query.
        """
        total = 0.0
        for name, target_weight in self.weights.items():
            ideal = 1.0 if target_weight > 0 else 0.0
            gap = concept_scores.get(name, 0.0) - ideal
            # Distance weighted by how strongly we care about this concept
            total += abs(target_weight) * gap ** 2
        return total
class QueryLanguage:
    """
    Parser and builder for concept queries.

    Supports a simple DSL:
      "recursive > 0.5 AND fast_execution > 0.7"          (hard constraints)
      "recursive=0.8, space_efficient=0.6, mutation=-0.3" (soft weights)
      "NOT mutation"                                      (avoid mutation)

    NOTE: "LIKE program_id_abc123" (find programs similar to a reference) is
    listed as part of the DSL, but no method of this class parses it; such
    text is ignored by ``parse``. Connectives such as AND are not interpreted
    either: every recognised fragment in the string is simply collected.
    """

    # Signed decimal literal; also accepts ".5" and exponent forms like "1e-3".
    _NUMBER = r"[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?"
    _OPERATOR = r">=|<=|>|<|=="
    # "name = number": a soft weight assignment.
    _WEIGHT_RE = re.compile(r"(\w+)\s*=\s*(" + _NUMBER + r")")
    # "name <op> number": a hard constraint.
    _CONSTRAINT_RE = re.compile(r"(\w+)\s*(" + _OPERATOR + r")\s*(" + _NUMBER + r")")
    # "NOT name"; the \b keeps words such as "cannot" from triggering it.
    _NOT_RE = re.compile(r"\bNOT\s+(\w+)", re.IGNORECASE)
    # "[<op>] number", as passed to constrain(); the operator is optional.
    _CONSTRAINT_EXPR_RE = re.compile(r"(" + _OPERATOR + r")?\s*(" + _NUMBER + r")")

    def __init__(self, concept_set: ConceptSet):
        """Store the concept set used to validate concept names."""
        self.concept_set = concept_set

    def _is_known(self, name: str) -> bool:
        """Return True if ``name`` resolves in the concept set; warn otherwise."""
        if self.concept_set.get_by_name(name):
            return True
        logger.warning("Unknown concept: %s", name)
        return False

    def parse(self, query_str: str) -> ConceptQuery:
        """
        Parse a query string into a ConceptQuery.

        Three kinds of fragment are recognised anywhere in the string:
        weight assignments ("recursive=0.8"), constraint expressions
        ("recursive > 0.5") and negations ("NOT mutation"). Unknown concept
        names are skipped with a warning.

        Args:
            query_str: The query text.

        Returns:
            A ConceptQuery holding the collected weights and constraints.
        """
        query = ConceptQuery()

        # Handle comma-separated weight assignments: "recursive=0.8, mutation=-0.3"
        for match in self._WEIGHT_RE.finditer(query_str):
            name = match.group(1)
            if self._is_known(name):
                query.weights[name] = float(match.group(2))

        # Handle constraint expressions: "recursive > 0.5"
        for match in self._CONSTRAINT_RE.finditer(query_str):
            name, op, raw_threshold = match.groups()
            if name in query.weights:  # Don't double-count
                continue
            if self._is_known(name):
                query.constraints[name] = (op, float(raw_threshold))

        # Handle NOT: "NOT mutation". An explicit weight given for the same
        # concept takes precedence over the default of -1.0.
        for match in self._NOT_RE.finditer(query_str):
            name = match.group(1)
            if self._is_known(name):
                query.weights.setdefault(name, -1.0)

        return query

    def build(self, **concept_weights: float) -> ConceptQuery:
        """
        Build a query from keyword arguments.

        Example: build(recursive=0.8, mutation=-0.3)

        Unknown concept names are skipped with a warning.
        """
        validated = {
            name: weight
            for name, weight in concept_weights.items()
            if self._is_known(name)
        }
        return ConceptQuery(weights=validated)

    def constrain(self, **constraints: str) -> ConceptQuery:
        """
        Build a constraint query.

        Example: constrain(recursive=">0.5", fast_execution=">=0.7")

        Each value is "<op><number>"; when the operator is omitted it defaults
        to ">". Surrounding whitespace is ignored. Unknown concept names and
        expressions that do not parse are skipped with a warning.
        """
        query = ConceptQuery()
        for name, expr in constraints.items():
            if not self._is_known(name):
                continue
            # fullmatch so that trailing garbage (">0.5abc") is rejected
            # rather than silently truncated; str() lets a bare number through.
            match = self._CONSTRAINT_EXPR_RE.fullmatch(str(expr).strip())
            if match is None:
                logger.warning(
                    "Ignoring unparseable constraint for %s: %r", name, expr
                )
                continue
            op = match.group(1) or ">"
            query.constraints[name] = (op, float(match.group(2)))
        return query
class SteeringEngine:
    """
    Engine for steering program generation using concept queries.

    Two modes:
      1. Selection: Given a ProgramSpace and a query, rank/filter programs
      2. Generation: Given a query, steer LLM hidden states during generation

    Selection uses concept scores directly.
    Generation uses GCAV vectors (e' = e + ε·v) or MSRS orthogonal steering.

    Every public method that takes a query accepts either a ConceptQuery or a
    query string; strings are parsed with ``self.query_language``.
    """

    def __init__(
        self,
        concept_set: ConceptSet,
        embedding_space: Optional[ConceptEmbeddingSpace] = None,
        gcav: Optional[GCAVEmbedding] = None,
        msrs: Optional[MSRSSteering] = None,
    ):
        """
        Args:
            concept_set: Concepts used for scoring programs and name lookup.
            embedding_space: Stored on the instance; not read by this class.
            gcav: Source of per-concept vectors for 'additive' steering.
            msrs: Backend for 'msrs' steering.
        """
        self.concept_set = concept_set
        self.embedding_space = embedding_space
        self.gcav = gcav
        self.msrs = msrs
        self.query_language = QueryLanguage(concept_set)

    # ---- Selection Mode ----

    def select(
        self,
        space: ProgramSpace,
        query: Union[ConceptQuery, str],
        top_k: int = 10,
    ) -> list[tuple[Program, float]]:
        """
        Select programs from the space that best match the query.

        Programs failing the query's hard constraints are dropped; the rest
        are ranked by the soft relevance score.

        Returns list of (program, relevance_score) tuples, sorted by relevance
        (highest first) and truncated to ``top_k`` entries.
        """
        if isinstance(query, str):
            query = self.query_language.parse(query)

        scored: list[tuple[Program, float]] = []
        for program in space.valid_programs:
            concept_scores = self.concept_set.score_program(program)
            # Check hard constraints
            if not query.matches(concept_scores):
                continue
            # Compute soft relevance score
            scored.append((program, self._compute_relevance(concept_scores, query)))

        # Sort by relevance (higher = better match)
        scored.sort(key=lambda item: item[1], reverse=True)
        return scored[:top_k]

    def _compute_relevance(
        self,
        concept_scores: dict[str, float],
        query: ConceptQuery,
    ) -> float:
        """
        Compute relevance of a program to a query.

        For positive weights: reward high concept scores
        For negative weights: reward low concept scores

        The result is the weighted average of those rewards, normalised by the
        sum of absolute weights; it is 0.0 when the query has no (non-zero)
        weights. Concepts missing from ``concept_scores`` count as 0.0.
        """
        relevance = 0.0
        total_weight = 0.0
        for name, target_weight in query.weights.items():
            actual = concept_scores.get(name, 0.0)
            if target_weight > 0:
                relevance += target_weight * actual
            else:
                relevance += abs(target_weight) * (1.0 - actual)
            total_weight += abs(target_weight)
        if total_weight > 0:
            relevance /= total_weight
        return relevance

    def filter(
        self,
        space: ProgramSpace,
        query: Union[ConceptQuery, str],
    ) -> ProgramSpace:
        """
        Filter a ProgramSpace to programs matching the query.

        Only the query's hard constraints are applied; weights are ignored.
        Returns a new ProgramSpace built from ``space.stub``.
        """
        if isinstance(query, str):
            query = self.query_language.parse(query)

        filtered = ProgramSpace(space.stub)
        for program in space.valid_programs:
            if query.matches(self.concept_set.score_program(program)):
                filtered.add(program)
        return filtered

    # ---- Generation Steering Mode ----

    def build_steering_vector(
        self,
        query: Union[ConceptQuery, str],
        method: str = "additive",
    ) -> Optional[np.ndarray]:
        """
        Build a steering vector from a concept query.

        Args:
            query: The concept query
            method: 'additive' (GCAV sum) or 'msrs' (orthogonal subspaces)

        Returns:
            Steering vector in activation space, or None when the backend for
            the requested method is missing, the GCAV backend holds no concept
            vectors, or ``method`` is not recognised.
        """
        if isinstance(query, str):
            query = self.query_language.parse(query)

        if method == "additive" and self.gcav is not None:
            vectors = self.gcav.concept_vectors
            # Any stored vector serves as the shape/dtype template. With no
            # vectors at all there is nothing to size the result from.
            template = next(iter(vectors.values()), None)
            if template is None:
                return None
            template = np.asarray(template)
            # Keep float/complex dtypes as they are, but accumulate integer
            # (or bool) vectors in float64 so the in-place add of
            # float-weighted vectors below cannot fail on casting.
            if np.issubdtype(template.dtype, np.inexact):
                dtype = template.dtype
            else:
                dtype = np.float64
            # Simple additive: v_steer = Σ w_i · v_i
            steer = np.zeros(template.shape, dtype=dtype)
            for name, weight in query.weights.items():
                if name in vectors:
                    steer += weight * np.asarray(vectors[name])
            return steer

        if method == "msrs" and self.msrs is not None:
            # Use MSRS orthogonal steering: steer a zero vector and report the
            # displacement relative to it.
            base = np.zeros(self.msrs.S_align.shape[1])
            return self.msrs.steer(base, query.weights) - base

        return None

    def steer_prompt(
        self,
        query: Union[ConceptQuery, str],
        base_prompt: str,
    ) -> str:
        """
        Augment a generation prompt with concept-steering instructions.

        This is a lightweight steering approach that works with any LLM API
        (no hidden state access needed). For stronger steering, use
        build_steering_vector with activation-level intervention.

        Weights are listed strongest first: |w| > 0.5 becomes a STRONGLY
        PREFER / STRONGLY AVOID line, a smaller non-zero weight a PREFER /
        AVOID line, and a zero weight produces nothing. Each constraint adds a
        CONSTRAINT line. Names that do not resolve in the concept set are
        skipped. If no instruction results, ``base_prompt`` is returned as-is.
        """
        if isinstance(query, str):
            query = self.query_language.parse(query)

        concept_instructions = []
        by_magnitude = sorted(query.weights.items(), key=lambda item: -abs(item[1]))
        for name, weight in by_magnitude:
            concept = self.concept_set.get_by_name(name)
            if concept is None:
                continue
            if weight > 0.5:
                concept_instructions.append(
                    f"STRONGLY PREFER: {concept.description}"
                )
            elif weight > 0:
                concept_instructions.append(
                    f"PREFER: {concept.description}"
                )
            elif weight < -0.5:
                concept_instructions.append(
                    f"STRONGLY AVOID: {concept.description}"
                )
            elif weight < 0:
                concept_instructions.append(
                    f"AVOID: {concept.description}"
                )

        for name, (op, threshold) in query.constraints.items():
            concept = self.concept_set.get_by_name(name)
            # Same presence check as in the weights loop above.
            if concept is None:
                continue
            concept_instructions.append(
                f"CONSTRAINT: {concept.description} ({op} {threshold})"
            )

        if not concept_instructions:
            return base_prompt

        steering_block = "\n".join(
            f" - {inst}" for inst in concept_instructions
        )
        return (
            f"{base_prompt}\n\n"
            f"CONCEPT STEERING INSTRUCTIONS:\n{steering_block}\n\n"
            f"Follow the above concept preferences when implementing."
        )

    # ---- Exploration Mode ----

    def explore_neighbors(
        self,
        program: Program,
        space: ProgramSpace,
        n_neighbors: int = 5,
    ) -> list[tuple[Program, float, dict[str, float]]]:
        """
        Find programs in the space that are conceptually nearby.

        Returns list of (program, distance, concept_diff) tuples, nearest
        first. ``distance`` is the Euclidean distance between the two concept
        score mappings (a missing concept counts as 0.0). ``concept_diff``
        maps concept name to (reference score - other score) for every concept
        whose absolute difference exceeds 0.01, in concept-name order.
        The reference program itself (same ``program_id``) is excluded.
        """
        ref_scores = self.concept_set.score_program(program)
        neighbors: list[tuple[Program, float, dict[str, float]]] = []

        for other in space.valid_programs:
            if other.program_id == program.program_id:
                continue
            other_scores = self.concept_set.score_program(other)

            # Euclidean distance in concept space. Names are visited in sorted
            # order so both the diff ordering and the floating-point sum are
            # reproducible from run to run (plain set iteration is not).
            diff: dict[str, float] = {}
            dist_sq = 0.0
            for name in sorted(set(ref_scores) | set(other_scores)):
                d = ref_scores.get(name, 0.0) - other_scores.get(name, 0.0)
                if abs(d) > 0.01:
                    diff[name] = d
                dist_sq += d ** 2
            neighbors.append((other, dist_sq ** 0.5, diff))

        neighbors.sort(key=lambda item: item[1])
        return neighbors[:n_neighbors]

    def concept_boundary_programs(
        self,
        concept_name: str,
        space: ProgramSpace,
        n_per_side: int = 3,
    ) -> dict[str, list[Program]]:
        """
        Find programs at the boundary of a concept.

        A program is "inside" when ``concept.score(program) > 0.5`` and
        "outside" otherwise. Returns up to ``n_per_side`` programs from each
        side, those closest to the 0.5 boundary first, under the keys
        "inside" and "outside". Both lists are empty for an unknown concept.
        """
        concept = self.concept_set.get_by_name(concept_name)
        if concept is None:
            return {"inside": [], "outside": []}

        inside: list[tuple[Program, float]] = []
        outside: list[tuple[Program, float]] = []
        for program in space.valid_programs:
            score = concept.score(program)
            bucket = inside if score > 0.5 else outside
            bucket.append((program, score))

        # Sort inside by score ascending (closest to boundary)
        inside.sort(key=lambda item: item[1])
        # Sort outside by score descending (closest to boundary)
        outside.sort(key=lambda item: item[1], reverse=True)

        return {
            "inside": [p for p, _ in inside[:n_per_side]],
            "outside": [p for p, _ in outside[:n_per_side]],
        }