""" Stage 4: Query Language and Steering Interface Provides a formal query language for navigating the program space by composing concepts. Users express preferences as concept coordinates, and the system steers LLM generation accordingly. The query language supports: - Single concept selection: steer("recursive", strength=0.8) - Concept composition: steer(recursive=0.8, space_efficient=0.6) - Concept negation: steer(mutation=-0.5) (avoid mutation) - Region queries: select(region_where(recursive > 0.5, fast_execution > 0.7)) - Lattice navigation: refine(current, add_concept="memoization") """ from __future__ import annotations import logging import re from dataclasses import dataclass, field from typing import Any, Optional, Union import numpy as np from reason_first_program.program_space import Program, ProgramSpace from reason_first_program.concepts import Concept, ConceptSet from reason_first_program.embeddings import ( ConceptEmbeddingSpace, GCAVEmbedding, MSRSSteering, ) logger = logging.getLogger(__name__) @dataclass class ConceptQuery: """ A query in the concept space. A query is a weighted combination of concepts that defines a target region in the program space. The system either: 1. Selects existing programs nearest to this region, or 2. Steers generation toward this region. Formally: q = Σ_i w_i · v_i where v_i is concept i's activation vector """ weights: dict[str, float] = field(default_factory=dict) constraints: dict[str, tuple[str, float]] = field(default_factory=dict) # constraints: {concept_name: (operator, threshold)} e.g., {"recursive": (">", 0.5)} metadata: dict[str, Any] = field(default_factory=dict) def __repr__(self) -> str: parts = [] for name, weight in sorted(self.weights.items(), key=lambda x: -abs(x[1])): if weight > 0: parts.append(f"+{weight:.1f}·{name}") else: parts.append(f"{weight:.1f}·{name}") for name, (op, val) in self.constraints.items(): parts.append(f"{name}{op}{val:.1f}") return f"Query({', '.join(parts)})" @property def concept_vector(self) -> dict[str, float]: """The query as a concept-space direction vector.""" return self.weights.copy() def matches(self, concept_scores: dict[str, float]) -> bool: """Check if a program's concept scores satisfy the query constraints.""" for name, (op, threshold) in self.constraints.items(): score = concept_scores.get(name, 0.0) if op == ">" and not (score > threshold): return False elif op == ">=" and not (score >= threshold): return False elif op == "<" and not (score < threshold): return False elif op == "<=" and not (score <= threshold): return False elif op == "==" and not (abs(score - threshold) < 0.05): return False return True def distance_to(self, concept_scores: dict[str, float]) -> float: """ Compute distance from a program's concept scores to this query. Lower = more aligned with query. """ total = 0.0 for name, target_weight in self.weights.items(): actual = concept_scores.get(name, 0.0) # Distance weighted by how strongly we care about this concept total += abs(target_weight) * (actual - (1.0 if target_weight > 0 else 0.0)) ** 2 return total class QueryLanguage: """ Parser and builder for concept queries. Supports a simple DSL: "recursive > 0.5 AND fast_execution > 0.7" "recursive=0.8, space_efficient=0.6, mutation=-0.3" "LIKE program_id_abc123" (find programs similar to a reference) "NOT mutation" (avoid mutation) """ def __init__(self, concept_set: ConceptSet): self.concept_set = concept_set def parse(self, query_str: str) -> ConceptQuery: """Parse a query string into a ConceptQuery.""" query = ConceptQuery() # Handle comma-separated weight assignments: "recursive=0.8, mutation=-0.3" weight_pattern = r"(\w+)\s*=\s*(-?\d+\.?\d*)" for match in re.finditer(weight_pattern, query_str): name = match.group(1) weight = float(match.group(2)) if self.concept_set.get_by_name(name): query.weights[name] = weight # Handle constraint expressions: "recursive > 0.5" constraint_pattern = r"(\w+)\s*(>=|<=|>|<|==)\s*(-?\d+\.?\d*)" for match in re.finditer(constraint_pattern, query_str): name = match.group(1) op = match.group(2) threshold = float(match.group(3)) if name not in query.weights: # Don't double-count if self.concept_set.get_by_name(name): query.constraints[name] = (op, threshold) # Handle NOT: "NOT mutation" not_pattern = r"NOT\s+(\w+)" for match in re.finditer(not_pattern, query_str, re.IGNORECASE): name = match.group(1) if self.concept_set.get_by_name(name): query.weights[name] = query.weights.get(name, -1.0) return query def build(self, **concept_weights: float) -> ConceptQuery: """Build a query from keyword arguments.""" validated = {} for name, weight in concept_weights.items(): if self.concept_set.get_by_name(name): validated[name] = weight else: logger.warning(f"Unknown concept: {name}") return ConceptQuery(weights=validated) def constrain(self, **constraints: str) -> ConceptQuery: """ Build a constraint query. Example: constrain(recursive=">0.5", fast_execution=">=0.7") """ query = ConceptQuery() for name, expr in constraints.items(): if not self.concept_set.get_by_name(name): logger.warning(f"Unknown concept: {name}") continue match = re.match(r"(>=|<=|>|<|==)?\s*(-?\d+\.?\d*)", expr) if match: op = match.group(1) or ">" threshold = float(match.group(2)) query.constraints[name] = (op, threshold) return query class SteeringEngine: """ Engine for steering program generation using concept queries. Two modes: 1. Selection: Given a ProgramSpace and a query, rank/filter programs 2. Generation: Given a query, steer LLM hidden states during generation Selection uses concept scores directly. Generation uses GCAV vectors (e' = e + ε·v) or MSRS orthogonal steering. """ def __init__( self, concept_set: ConceptSet, embedding_space: Optional[ConceptEmbeddingSpace] = None, gcav: Optional[GCAVEmbedding] = None, msrs: Optional[MSRSSteering] = None, ): self.concept_set = concept_set self.embedding_space = embedding_space self.gcav = gcav self.msrs = msrs self.query_language = QueryLanguage(concept_set) # ---- Selection Mode ---- def select( self, space: ProgramSpace, query: Union[ConceptQuery, str], top_k: int = 10, ) -> list[tuple[Program, float]]: """ Select programs from the space that best match the query. Returns list of (program, relevance_score) tuples, sorted by relevance. """ if isinstance(query, str): query = self.query_language.parse(query) scored: list[tuple[Program, float]] = [] for program in space.valid_programs: concept_scores = self.concept_set.score_program(program) # Check hard constraints if not query.matches(concept_scores): continue # Compute soft relevance score relevance = self._compute_relevance(concept_scores, query) scored.append((program, relevance)) # Sort by relevance (higher = better match) scored.sort(key=lambda x: x[1], reverse=True) return scored[:top_k] def _compute_relevance( self, concept_scores: dict[str, float], query: ConceptQuery, ) -> float: """ Compute relevance of a program to a query. For positive weights: reward high concept scores For negative weights: reward low concept scores """ relevance = 0.0 total_weight = 0.0 for name, target_weight in query.weights.items(): actual = concept_scores.get(name, 0.0) if target_weight > 0: relevance += target_weight * actual else: relevance += abs(target_weight) * (1.0 - actual) total_weight += abs(target_weight) if total_weight > 0: relevance /= total_weight return relevance def filter( self, space: ProgramSpace, query: Union[ConceptQuery, str], ) -> ProgramSpace: """Filter a ProgramSpace to programs matching the query.""" if isinstance(query, str): query = self.query_language.parse(query) filtered = ProgramSpace(space.stub) for program in space.valid_programs: concept_scores = self.concept_set.score_program(program) if query.matches(concept_scores): filtered.add(program) return filtered # ---- Generation Steering Mode ---- def build_steering_vector( self, query: Union[ConceptQuery, str], method: str = "additive", ) -> Optional[np.ndarray]: """ Build a steering vector from a concept query. Args: query: The concept query method: 'additive' (GCAV sum) or 'msrs' (orthogonal subspaces) Returns: Steering vector in activation space, or None if no GCAV available """ if isinstance(query, str): query = self.query_language.parse(query) if method == "additive" and self.gcav is not None: # Simple additive: v_steer = Σ w_i · v_i steer = np.zeros_like( next(iter(self.gcav.concept_vectors.values())) ) for name, weight in query.weights.items(): if name in self.gcav.concept_vectors: steer += weight * self.gcav.concept_vectors[name] return steer elif method == "msrs" and self.msrs is not None: # Use MSRS orthogonal steering base = np.zeros(self.msrs.S_align.shape[1]) return self.msrs.steer(base, query.weights) - base return None def steer_prompt( self, query: Union[ConceptQuery, str], base_prompt: str, ) -> str: """ Augment a generation prompt with concept-steering instructions. This is a lightweight steering approach that works with any LLM API (no hidden state access needed). For stronger steering, use build_steering_vector with activation-level intervention. """ if isinstance(query, str): query = self.query_language.parse(query) concept_instructions = [] for name, weight in sorted( query.weights.items(), key=lambda x: -abs(x[1]) ): concept = self.concept_set.get_by_name(name) if concept is None: continue if weight > 0.5: concept_instructions.append( f"STRONGLY PREFER: {concept.description}" ) elif weight > 0: concept_instructions.append( f"PREFER: {concept.description}" ) elif weight < -0.5: concept_instructions.append( f"STRONGLY AVOID: {concept.description}" ) elif weight < 0: concept_instructions.append( f"AVOID: {concept.description}" ) for name, (op, threshold) in query.constraints.items(): concept = self.concept_set.get_by_name(name) if concept: concept_instructions.append( f"CONSTRAINT: {concept.description} ({op} {threshold})" ) if not concept_instructions: return base_prompt steering_block = "\n".join( f" - {inst}" for inst in concept_instructions ) return ( f"{base_prompt}\n\n" f"CONCEPT STEERING INSTRUCTIONS:\n{steering_block}\n\n" f"Follow the above concept preferences when implementing." ) # ---- Exploration Mode ---- def explore_neighbors( self, program: Program, space: ProgramSpace, n_neighbors: int = 5, ) -> list[tuple[Program, float, dict[str, float]]]: """ Find programs in the space that are conceptually nearby. Returns list of (program, distance, concept_diff) tuples. concept_diff shows which concepts differ most. """ ref_scores = self.concept_set.score_program(program) neighbors: list[tuple[Program, float, dict[str, float]]] = [] for other in space.valid_programs: if other.program_id == program.program_id: continue other_scores = self.concept_set.score_program(other) # Euclidean distance in concept space diff = {} dist_sq = 0.0 for name in set(ref_scores) | set(other_scores): d = ref_scores.get(name, 0.0) - other_scores.get(name, 0.0) if abs(d) > 0.01: diff[name] = d dist_sq += d ** 2 neighbors.append((other, dist_sq ** 0.5, diff)) neighbors.sort(key=lambda x: x[1]) return neighbors[:n_neighbors] def concept_boundary_programs( self, concept_name: str, space: ProgramSpace, n_per_side: int = 3, ) -> dict[str, list[Program]]: """ Find programs at the boundary of a concept. Returns programs just inside and just outside the concept region. """ concept = self.concept_set.get_by_name(concept_name) if concept is None: return {"inside": [], "outside": []} inside: list[tuple[Program, float]] = [] outside: list[tuple[Program, float]] = [] for program in space.valid_programs: score = concept.score(program) if score > 0.5: inside.append((program, score)) else: outside.append((program, score)) # Sort inside by score ascending (closest to boundary) inside.sort(key=lambda x: x[1]) # Sort outside by score descending (closest to boundary) outside.sort(key=lambda x: x[1], reverse=True) return { "inside": [p for p, _ in inside[:n_per_side]], "outside": [p for p, _ in outside[:n_per_side]], }