"""
Stage 4: Query Language and Steering Interface
Provides a formal query language for navigating the program space by composing
concepts. Users express preferences as concept coordinates, and the system
steers LLM generation accordingly.
The query language supports:
- Single concept selection: steer("recursive", strength=0.8)
- Concept composition: steer(recursive=0.8, space_efficient=0.6)
- Concept negation: steer(mutation=-0.5) (avoid mutation)
- Region queries: select(region_where(recursive > 0.5, fast_execution > 0.7))
- Lattice navigation: refine(current, add_concept="memoization")
"""
from __future__ import annotations

import logging
import operator
import re
from dataclasses import dataclass, field
from typing import Any, Optional, Union

import numpy as np

from reason_first_program.program_space import Program, ProgramSpace
from reason_first_program.concepts import Concept, ConceptSet
from reason_first_program.embeddings import (
    ConceptEmbeddingSpace,
    GCAVEmbedding,
    MSRSSteering,
)
# Module-level logger; the classes below use it to report unknown concept names.
logger = logging.getLogger(__name__)
@dataclass
class ConceptQuery:
    """
    A query in the concept space.

    A query is a weighted combination of concepts that defines a target
    region in the program space. The system either:
    1. Selects existing programs nearest to this region, or
    2. Steers generation toward this region.

    Formally: q = Σ_i w_i · v_i where v_i is concept i's activation vector.

    Attributes:
        weights: Soft preferences, concept name -> weight. Positive weights
            prefer the concept; zero or negative weights prefer its absence.
        constraints: Hard filters, concept name -> (operator, threshold),
            e.g. {"recursive": (">", 0.5)}. Supported operators are ">",
            ">=", "<", "<=" and "==" (compared within ``_EQ_TOLERANCE``).
        metadata: Free-form extra information; not read by this class.
    """

    weights: dict[str, float] = field(default_factory=dict)
    # constraints: {concept_name: (operator, threshold)} e.g., {"recursive": (">", 0.5)}
    constraints: dict[str, tuple[str, float]] = field(default_factory=dict)
    metadata: dict[str, Any] = field(default_factory=dict)

    # Ordering comparisons usable in ``constraints``. These class attributes
    # carry no annotation, so @dataclass does not turn them into fields.
    _ORDER_OPS = {
        ">": operator.gt,
        ">=": operator.ge,
        "<": operator.lt,
        "<=": operator.le,
    }
    # Absolute tolerance for "==" constraints, since scores are floats.
    _EQ_TOLERANCE = 0.05

    def __repr__(self) -> str:
        """Compact form, e.g. ``Query(+0.8·recursive, -0.3·mutation, fast>0.7)``."""
        parts = []
        # Strongest preferences (by magnitude) first.
        for name, weight in sorted(self.weights.items(), key=lambda x: -abs(x[1])):
            sign = "+" if weight > 0 else ""
            parts.append(f"{sign}{weight:.1f}·{name}")
        for name, (op, val) in self.constraints.items():
            parts.append(f"{name}{op}{val:.1f}")
        return f"Query({', '.join(parts)})"

    @property
    def concept_vector(self) -> dict[str, float]:
        """The query as a concept-space direction vector (a copy of ``weights``)."""
        return self.weights.copy()

    def matches(self, concept_scores: dict[str, float]) -> bool:
        """
        Check whether a program's concept scores satisfy every constraint.

        Args:
            concept_scores: concept name -> score. Missing concepts score 0.0.

        Returns:
            True if all constraints hold (vacuously True with none).
            A constraint with an unsupported operator is logged and ignored.
        """
        for name, (op, threshold) in self.constraints.items():
            score = concept_scores.get(name, 0.0)
            if op == "==":
                satisfied = abs(score - threshold) < self._EQ_TOLERANCE
            elif op in self._ORDER_OPS:
                satisfied = self._ORDER_OPS[op](score, threshold)
            else:
                # Previously ignored silently, which let every program pass.
                logger.warning(
                    "Ignoring constraint on %s: unsupported operator %r", name, op
                )
                continue
            if not satisfied:
                return False
        return True

    def distance_to(self, concept_scores: dict[str, float]) -> float:
        """
        Compute the weighted squared distance from concept scores to this query.

        Each weighted concept has a target of 1.0 (positive weight) or 0.0
        (zero/negative weight). The squared error is scaled by ``abs(weight)``.
        Lower = more aligned with the query.
        """
        total = 0.0
        for name, target_weight in self.weights.items():
            actual = concept_scores.get(name, 0.0)
            target = 1.0 if target_weight > 0 else 0.0
            # Distance weighted by how strongly we care about this concept
            total += abs(target_weight) * (actual - target) ** 2
        return total
class QueryLanguage:
    """
    Parser and builder for concept queries.

    ``parse`` understands a simple DSL. The forms may be mixed in one string:
        "recursive=0.8, space_efficient=0.6, mutation=-0.3"  -> weights
        "recursive > 0.5 AND fast_execution > 0.7"           -> constraints
        "NOT mutation"                                        -> weight -1.0

    Every constraint found is applied, so "AND" is accepted but not itself
    interpreted. Reference queries such as "LIKE program_id_abc123" are not
    implemented by ``parse``; such text is ignored. Concept names unknown to
    the concept set are dropped with a warning.
    """

    # Compiled once; these are applied on every parse()/constrain() call.
    _WEIGHT_RE = re.compile(r"(\w+)\s*=\s*(-?\d+\.?\d*)")
    _CONSTRAINT_RE = re.compile(r"(\w+)\s*(>=|<=|>|<|==)\s*(-?\d+\.?\d*)")
    # \b stops "NOT" from matching the tail of words such as "cannot".
    _NOT_RE = re.compile(r"\bNOT\s+(\w+)", re.IGNORECASE)
    # Operator is optional in constrain() expressions and defaults to ">".
    _CONSTRAIN_EXPR_RE = re.compile(r"(>=|<=|>|<|==)?\s*(-?\d+\.?\d*)")

    def __init__(self, concept_set: ConceptSet):
        self.concept_set = concept_set

    def _is_known(self, name: str) -> bool:
        """Return True if *name* is in the concept set; log a warning otherwise."""
        if self.concept_set.get_by_name(name):
            return True
        logger.warning("Unknown concept: %s", name)
        return False

    def parse(self, query_str: str) -> ConceptQuery:
        """
        Parse a query string into a ConceptQuery.

        Args:
            query_str: DSL text (see the class docstring).

        Returns:
            A ConceptQuery with weights from ``name=value`` and ``NOT name``
            and constraints from ``name <op> value``.
        """
        query = ConceptQuery()

        # Weight assignments: "recursive=0.8, mutation=-0.3"
        for match in self._WEIGHT_RE.finditer(query_str):
            name, weight = match.group(1), float(match.group(2))
            if self._is_known(name):
                query.weights[name] = weight

        # Constraint expressions: "recursive > 0.5"
        for match in self._CONSTRAINT_RE.finditer(query_str):
            name, op, threshold = match.groups()
            # A concept that already carries a weight is not also constrained.
            if name not in query.weights and self._is_known(name):
                query.constraints[name] = (op, float(threshold))

        # Negation: "NOT mutation". An explicit weight for the name wins.
        for match in self._NOT_RE.finditer(query_str):
            name = match.group(1)
            if self._is_known(name):
                query.weights.setdefault(name, -1.0)

        return query

    def build(self, **concept_weights: float) -> ConceptQuery:
        """Build a weight-only query from keyword arguments, dropping unknown names."""
        return ConceptQuery(
            weights={
                name: weight
                for name, weight in concept_weights.items()
                if self._is_known(name)
            }
        )

    def constrain(self, **constraints: Union[str, float]) -> ConceptQuery:
        """
        Build a constraint query.

        Example: constrain(recursive=">0.5", fast_execution=">=0.7")

        Each value is an expression "<op><number>" with op one of
        >=, <=, >, <, == (default ">"). A bare number (e.g. ``recursive=0.5``)
        is treated as ``">" number``. Unknown concepts and unparseable
        expressions are skipped with a warning.
        """
        query = ConceptQuery()
        for name, expr in constraints.items():
            if not self._is_known(name):
                continue
            if isinstance(expr, (int, float)) and not isinstance(expr, bool):
                # Handle numbers directly; str() could yield e.g. "1e-05".
                query.constraints[name] = (">", float(expr))
                continue
            match = self._CONSTRAIN_EXPR_RE.match(str(expr).strip())
            if match is None:
                logger.warning("Unparseable constraint for %s: %r", name, expr)
                continue
            op = match.group(1) or ">"
            query.constraints[name] = (op, float(match.group(2)))
        return query
class SteeringEngine:
    """
    Engine for steering program generation using concept queries.

    Two modes:
    1. Selection: Given a ProgramSpace and a query, rank/filter programs
    2. Generation: Given a query, steer LLM hidden states during generation

    Selection uses concept scores directly.
    Generation uses GCAV vectors (e' = e + ε·v) or MSRS orthogonal steering.

    Every public method taking ``query`` accepts either a ConceptQuery or a
    DSL string, which is parsed with ``QueryLanguage.parse``.
    """

    def __init__(
        self,
        concept_set: ConceptSet,
        embedding_space: Optional[ConceptEmbeddingSpace] = None,
        gcav: Optional[GCAVEmbedding] = None,
        msrs: Optional[MSRSSteering] = None,
    ):
        self.concept_set = concept_set
        # Stored for callers; not read by any method of this class.
        self.embedding_space = embedding_space
        self.gcav = gcav
        self.msrs = msrs
        self.query_language = QueryLanguage(concept_set)

    def _as_query(self, query: Union[ConceptQuery, str]) -> ConceptQuery:
        """Return *query* as a ConceptQuery, parsing it first if it is a string."""
        if isinstance(query, str):
            return self.query_language.parse(query)
        return query

    # ---- Selection Mode ----

    def select(
        self,
        space: ProgramSpace,
        query: Union[ConceptQuery, str],
        top_k: int = 10,
    ) -> list[tuple[Program, float]]:
        """
        Select programs from the space that best match the query.

        Programs failing any hard constraint are dropped. The rest are ranked
        by ``_compute_relevance``.

        Returns:
            Up to ``top_k`` (program, relevance_score) tuples sorted by
            descending relevance. Ties keep ``space.valid_programs`` order.
        """
        query = self._as_query(query)
        scored: list[tuple[Program, float]] = []
        for program in space.valid_programs:
            concept_scores = self.concept_set.score_program(program)
            if not query.matches(concept_scores):  # hard constraints
                continue
            scored.append((program, self._compute_relevance(concept_scores, query)))
        # Stable sort, so equal relevance keeps the original order.
        scored.sort(key=lambda item: item[1], reverse=True)
        return scored[:top_k]

    def _compute_relevance(
        self,
        concept_scores: dict[str, float],
        query: ConceptQuery,
    ) -> float:
        """
        Compute the relevance of a program's concept scores to a query.

        Positive weights reward high scores (``w * score``). Zero or negative
        weights reward low scores (``|w| * (1 - score)``). The sum is divided
        by the total absolute weight. The result therefore stays in [0, 1] if
        scores do (presumably they do; this is not enforced here). Returns 0.0
        for a query without weights.
        """
        relevance = 0.0
        total_weight = 0.0
        for name, target_weight in query.weights.items():
            actual = concept_scores.get(name, 0.0)
            if target_weight > 0:
                relevance += target_weight * actual
            else:
                relevance += abs(target_weight) * (1.0 - actual)
            total_weight += abs(target_weight)
        if total_weight > 0:
            relevance /= total_weight
        return relevance

    def filter(
        self,
        space: ProgramSpace,
        query: Union[ConceptQuery, str],
    ) -> ProgramSpace:
        """Return a new ProgramSpace (same stub) holding programs that match the query."""
        query = self._as_query(query)
        filtered = ProgramSpace(space.stub)
        for program in space.valid_programs:
            if query.matches(self.concept_set.score_program(program)):
                filtered.add(program)
        return filtered

    # ---- Generation Steering Mode ----

    def build_steering_vector(
        self,
        query: Union[ConceptQuery, str],
        method: str = "additive",
    ) -> Optional[np.ndarray]:
        """
        Build a steering vector from a concept query.

        Args:
            query: The concept query
            method: 'additive' (GCAV sum) or 'msrs' (orthogonal subspaces)

        Returns:
            Steering vector in activation space. Returns None when the backend
            for ``method`` is not configured, when GCAV has no concept
            vectors, or when ``method`` is unknown (which is also logged).
        """
        query = self._as_query(query)

        if method == "additive" and self.gcav is not None:
            vectors = self.gcav.concept_vectors
            if not vectors:
                # Nothing to size the result by (previously StopIteration).
                return None
            # Simple additive: v_steer = Σ w_i · v_i
            steer = np.zeros_like(next(iter(vectors.values())))
            for name, weight in query.weights.items():
                if name in vectors:
                    steer += weight * vectors[name]
            return steer

        if method == "msrs" and self.msrs is not None:
            # Use MSRS orthogonal steering on a zero activation and return the
            # applied offset. Assumes S_align's second axis is the activation
            # dimension. TODO confirm against MSRSSteering.
            base = np.zeros(self.msrs.S_align.shape[1])
            return self.msrs.steer(base, query.weights) - base

        if method not in ("additive", "msrs"):
            logger.warning("Unknown steering method: %r", method)
        return None

    @staticmethod
    def _preference_label(weight: float) -> Optional[str]:
        """Map a concept weight to its prompt label; None for zero (or NaN) weights."""
        if weight > 0.5:
            return "STRONGLY PREFER"
        if weight > 0:
            return "PREFER"
        if weight < -0.5:
            return "STRONGLY AVOID"
        if weight < 0:
            return "AVOID"
        return None

    def steer_prompt(
        self,
        query: Union[ConceptQuery, str],
        base_prompt: str,
    ) -> str:
        """
        Augment a generation prompt with concept-steering instructions.

        This is a lightweight steering approach that works with any LLM API
        (no hidden state access needed). For stronger steering, use
        build_steering_vector with activation-level intervention.

        Weights map to labels as follows: > 0.5 STRONGLY PREFER,
        (0, 0.5] PREFER, < -0.5 STRONGLY AVOID, [-0.5, 0) AVOID. Constraints
        are listed afterwards. Concepts unknown to the concept set are
        skipped. If nothing applies, ``base_prompt`` is returned unchanged.
        """
        query = self._as_query(query)
        concept_instructions = []

        # Strongest preferences (by magnitude) first.
        for name, weight in sorted(query.weights.items(), key=lambda x: -abs(x[1])):
            concept = self.concept_set.get_by_name(name)
            label = self._preference_label(weight)
            if concept is None or label is None:
                continue
            concept_instructions.append(f"{label}: {concept.description}")

        for name, (op, threshold) in query.constraints.items():
            concept = self.concept_set.get_by_name(name)
            if concept is not None:
                concept_instructions.append(
                    f"CONSTRAINT: {concept.description} ({op} {threshold})"
                )

        if not concept_instructions:
            return base_prompt

        steering_block = "\n".join(f" - {inst}" for inst in concept_instructions)
        return (
            f"{base_prompt}\n\n"
            f"CONCEPT STEERING INSTRUCTIONS:\n{steering_block}\n\n"
            f"Follow the above concept preferences when implementing."
        )

    # ---- Exploration Mode ----

    def explore_neighbors(
        self,
        program: Program,
        space: ProgramSpace,
        n_neighbors: int = 5,
        diff_threshold: float = 0.01,
    ) -> list[tuple[Program, float, dict[str, float]]]:
        """
        Find programs in the space that are conceptually nearby.

        Distance is Euclidean over the union of concept names (missing
        scores count as 0.0). Every concept contributes to the distance.
        ``concept_diff`` (reference score minus other score) lists only the
        concepts whose absolute difference exceeds ``diff_threshold``.

        Returns:
            Up to ``n_neighbors`` (program, distance, concept_diff) tuples,
            nearest first. ``program`` itself (same program_id) is excluded.
        """
        ref_scores = self.concept_set.score_program(program)
        neighbors: list[tuple[Program, float, dict[str, float]]] = []

        for other in space.valid_programs:
            if other.program_id == program.program_id:
                continue
            other_scores = self.concept_set.score_program(other)

            diff: dict[str, float] = {}
            dist_sq = 0.0
            # Sorted for deterministic summation and diff ordering.
            for name in sorted(set(ref_scores) | set(other_scores)):
                d = ref_scores.get(name, 0.0) - other_scores.get(name, 0.0)
                dist_sq += d ** 2
                if abs(d) > diff_threshold:
                    diff[name] = d
            neighbors.append((other, dist_sq ** 0.5, diff))

        neighbors.sort(key=lambda item: item[1])
        return neighbors[:n_neighbors]

    def concept_boundary_programs(
        self,
        concept_name: str,
        space: ProgramSpace,
        n_per_side: int = 3,
        threshold: float = 0.5,
    ) -> dict[str, list[Program]]:
        """
        Find programs at the boundary of a concept.

        A program is "inside" when ``concept.score(program) > threshold``.

        Returns:
            {"inside": [...], "outside": [...]}, each holding at most
            ``n_per_side`` programs, closest to the threshold first. Both
            lists are empty if the concept is unknown.
        """
        concept = self.concept_set.get_by_name(concept_name)
        if concept is None:
            return {"inside": [], "outside": []}

        inside: list[tuple[Program, float]] = []
        outside: list[tuple[Program, float]] = []
        for program in space.valid_programs:
            score = concept.score(program)
            (inside if score > threshold else outside).append((program, score))

        # Inside: lowest scores are closest to the boundary.
        inside.sort(key=lambda item: item[1])
        # Outside: highest scores are closest to the boundary.
        outside.sort(key=lambda item: item[1], reverse=True)
        return {
            "inside": [p for p, _ in inside[:n_per_side]],
            "outside": [p for p, _ in outside[:n_per_side]],
        }