ryanyen22's picture
feat: add reason_first_program/embeddings.py
c133ebf verified
"""
Stage 3: Concept-Guided Embeddings
Project programs into a space where concept dimensions are explicit and interpretable.
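Two embedding approaches (the module also provides MSRSSteering for
interference-aware multi-concept steering and a ConceptEmbeddingSpace wrapper):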
1. ConceptBottleneckAE (CB-SAE, 2512.10805):
Encoder maps program → concept scores; decoder reconstructs from concept scores.
Each bottleneck dimension = a named concept.
2. GCAVEmbedding (GCAV, 2501.05764):
For each concept, train a linear classifier on LLM hidden states.
The concept activation vector = classifier normal direction.
Steering: e' = e + ε·v_concept
Verification:
- t-SNE/UMAP visualization
- DA@K in concept space vs raw embedding space
- AlgoSim label prediction from concept-space distances
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from typing import Any, Optional
import numpy as np
from reason_first_program.program_space import Program, ProgramSpace
from reason_first_program.concepts import Concept, ConceptSet
logger = logging.getLogger(__name__)
class ConceptBottleneckAE:
    """
    Concept Bottleneck Autoencoder.

    Architecture (from CB-SAE, 2512.10805):
        Encoder: program_features -> concept_scores (|C| dimensions)
        Decoder: concept_scores -> reconstructed_features

    The bottleneck forces the representation to go through named concept
    dimensions, making each axis interpretable.

    Training (PyTorch path):
        L = L_recon + λ_concept * L_concept_supervision + λ_sparse * L_sparsity

    If PyTorch cannot be imported, ``train`` falls back to a closed-form
    linear encoder/decoder fitted by least squares (no hidden layer, no
    sigmoid, no bias, no sparsity term).
    """

    def __init__(
        self,
        n_concepts: int,
        input_dim: int,
        hidden_dim: int = 256,
        sparsity_weight: float = 0.01,
        concept_supervision_weight: float = 1.0,
        learning_rate: float = 1e-3,
        n_epochs: int = 100,
    ):
        self.n_concepts = n_concepts
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.sparsity_weight = sparsity_weight
        self.concept_supervision_weight = concept_supervision_weight
        self.learning_rate = learning_rate
        self.n_epochs = n_epochs
        # Snapshots of the trained parameters as numpy arrays, keyed by the
        # torch state-dict name (torch path) or {"linear": matrix} (numpy
        # fallback). None until ``train`` has run.
        self.encoder_weights: Optional[dict[str, np.ndarray]] = None
        self.decoder_weights: Optional[dict[str, np.ndarray]] = None
        self.concept_names: list[str] = []
        # Live torch modules; only set by the torch training path and used by
        # ``encode``/``decode`` in preference to the numpy snapshots.
        self._encoder: Any = None
        self._decoder: Any = None

    def _validate_shapes(
        self,
        features: np.ndarray,
        concept_labels: np.ndarray,
    ) -> None:
        """Raise ValueError unless features/labels match the configured dims."""
        if features.ndim != 2 or features.shape[1] != self.input_dim:
            raise ValueError(
                f"features must have shape (n_programs, {self.input_dim}), "
                f"got {features.shape}"
            )
        expected = (features.shape[0], self.n_concepts)
        if concept_labels.shape != expected:
            raise ValueError(
                f"concept_labels must have shape {expected}, "
                f"got {concept_labels.shape}"
            )

    def train(
        self,
        features: np.ndarray,
        concept_labels: np.ndarray,
        concept_names: list[str],
    ) -> dict[str, Any]:
        """
        Train the concept bottleneck autoencoder.

        Args:
            features: (n_programs, input_dim) program feature vectors.
            concept_labels: (n_programs, n_concepts) concept supervision
                labels. The torch path applies binary cross-entropy to sigmoid
                scores, so label values must lie in [0, 1].
            concept_names: Names for each bottleneck dimension.

        Returns:
            Training metrics dict. The torch path reports the losses of the
            last epoch (computed before that epoch's optimizer step). With
            ``n_epochs <= 0`` it reports the losses of the untrained model.

        Raises:
            ValueError: If the array shapes do not match ``input_dim`` /
                ``n_concepts``.
        """
        features = np.asarray(features)
        concept_labels = np.asarray(concept_labels)
        self._validate_shapes(features, concept_labels)
        try:
            import torch
            import torch.nn as nn
            import torch.optim as optim
        except ImportError:
            return self._train_numpy(features, concept_labels, concept_names)

        self.concept_names = concept_names
        device = "cuda" if torch.cuda.is_available() else "cpu"

        encoder = nn.Sequential(
            nn.Linear(self.input_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.n_concepts),
            nn.Sigmoid(),  # scores in (0, 1), as required by the BCE term
        ).to(device)
        decoder = nn.Sequential(
            nn.Linear(self.n_concepts, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.input_dim),
        ).to(device)
        optimizer = optim.Adam(
            list(encoder.parameters()) + list(decoder.parameters()),
            lr=self.learning_rate,
        )

        X = torch.tensor(features, dtype=torch.float32, device=device)
        Y = torch.tensor(concept_labels, dtype=torch.float32, device=device)

        def compute_losses():
            """Return (total, reconstruction, concept) losses for the full batch."""
            concept_scores = encoder(X)
            reconstructed = decoder(concept_scores)
            l_recon = ((X - reconstructed) ** 2).mean()
            l_concept = nn.functional.binary_cross_entropy(concept_scores, Y)
            # Sparsity penalty: encourages few active concepts per program.
            l_sparse = concept_scores.abs().mean()
            total = (
                l_recon
                + self.concept_supervision_weight * l_concept
                + self.sparsity_weight * l_sparse
            )
            return total, l_recon, l_concept

        losses = []
        for epoch in range(self.n_epochs):
            loss, l_recon, l_concept = compute_losses()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if epoch % 20 == 0:
                logger.info(
                    f"CB-AE epoch {epoch}: loss={loss.item():.4f} "
                    f"recon={l_recon.item():.4f} concept={l_concept.item():.4f}"
                )
            losses.append(loss.item())

        if not losses:
            # n_epochs <= 0: nothing was optimised, but still report the
            # losses of the (untrained) model instead of crashing.
            with torch.no_grad():
                loss, l_recon, l_concept = compute_losses()
            losses.append(loss.item())

        # Store numpy snapshots of the trained weights.
        self.encoder_weights = {
            k: v.cpu().detach().numpy() for k, v in encoder.state_dict().items()
        }
        self.decoder_weights = {
            k: v.cpu().detach().numpy() for k, v in decoder.state_dict().items()
        }
        self._encoder = encoder
        self._decoder = decoder

        return {
            "final_loss": losses[-1],
            "final_recon_loss": l_recon.item(),
            "final_concept_loss": l_concept.item(),
        }

    def _train_numpy(
        self,
        features: np.ndarray,
        concept_labels: np.ndarray,
        concept_names: list[str],
    ) -> dict[str, Any]:
        """Fallback numpy-only training (closed-form linear model, no bias)."""
        self.concept_names = concept_names
        # Any previous torch modules would shadow the linear weights in
        # encode/decode, so drop them.
        self._encoder = None
        self._decoder = None

        # Encoder: least-squares W_enc (d, k) with concept_labels ≈ features @ W_enc.
        W_enc, *_ = np.linalg.lstsq(features, concept_labels, rcond=None)
        self.encoder_weights = {"linear": W_enc.T}  # stored as (k, d)

        # Decoder: least-squares W_dec (k, d) with features ≈ concept_scores @ W_dec.
        concept_scores = features @ W_enc  # (n, k)
        W_dec, *_ = np.linalg.lstsq(concept_scores, features, rcond=None)
        self.decoder_weights = {"linear": W_dec.T}  # stored as (d, k)

        recon = concept_scores @ W_dec
        recon_loss = float(np.mean((features - recon) ** 2))
        return {
            "final_loss": recon_loss,
            "final_recon_loss": recon_loss,
            "method": "numpy_linear",
        }

    def encode(self, features: np.ndarray) -> np.ndarray:
        """
        Encode programs into concept space.

        Returns (n_programs, n_concepts) concept scores.

        Raises:
            RuntimeError: If the model has not been trained.
        """
        encoder = getattr(self, "_encoder", None)
        if encoder is not None:
            import torch

            device = next(encoder.parameters()).device
            X = torch.tensor(features, dtype=torch.float32, device=device)
            with torch.no_grad():
                return encoder(X).cpu().numpy()
        if self.encoder_weights is not None and "linear" in self.encoder_weights:
            return np.asarray(features) @ self.encoder_weights["linear"].T
        raise RuntimeError("Model not trained yet")

    def decode(self, concept_scores: np.ndarray) -> np.ndarray:
        """
        Decode from concept space back to feature space.

        Raises:
            RuntimeError: If the model has not been trained.
        """
        decoder = getattr(self, "_decoder", None)
        if decoder is not None:
            import torch

            device = next(decoder.parameters()).device
            Z = torch.tensor(concept_scores, dtype=torch.float32, device=device)
            with torch.no_grad():
                return decoder(Z).cpu().numpy()
        if self.decoder_weights is not None and "linear" in self.decoder_weights:
            return np.asarray(concept_scores) @ self.decoder_weights["linear"].T
        raise RuntimeError("Model not trained yet")
class GCAVEmbedding:
    """
    Concept Activation Vector embedding.

    Based on GCAV (2501.05764): for each concept, the CAV is the normal
    direction of a logistic classifier that separates concept-positive from
    concept-negative activations.

    For concept d at layer l:
        P_d^(l)(e) = sigmoid(w_d^(l)^T · e + b_d^(l))
        v_d^(l) = w_d^(l) / ||w_d^(l)||

    Steering:
        e' = e + ε · v_concept
    """

    def __init__(self):
        # Unit-norm CAV per concept. Insertion (training) order defines the
        # column order returned by ``project``.
        self.concept_vectors: dict[str, np.ndarray] = {}
        # Classifier intercept b_d (belongs to the *unnormalised* weights w_d).
        self.concept_biases: dict[str, float] = {}
        # Fitted scikit-learn classifiers, kept for inspection/scoring.
        self.concept_classifiers: dict[str, Any] = {}

    def train_concept_vector(
        self,
        concept_name: str,
        positive_features: np.ndarray,
        negative_features: np.ndarray,
    ) -> dict[str, float]:
        """
        Train a concept activation vector from contrastive data.

        Requires scikit-learn (imported lazily).

        Args:
            concept_name: Name of the concept.
            positive_features: Features of programs exhibiting the concept.
            negative_features: Features of programs NOT exhibiting the concept.

        Returns:
            Training metrics: training-set accuracy, the concept name and the
            norm of the raw classifier weights.
        """
        from sklearn.linear_model import LogisticRegression

        X = np.vstack([positive_features, negative_features])
        y = np.array([1] * len(positive_features) + [0] * len(negative_features))

        clf = LogisticRegression(max_iter=1000, solver="lbfgs")
        clf.fit(X, y)

        # CAV = normalized classifier weights (Eq. 2 from GCAV); the epsilon
        # guards against a zero weight vector.
        w = clf.coef_[0]
        v = w / (np.linalg.norm(w) + 1e-8)

        self.concept_vectors[concept_name] = v
        self.concept_biases[concept_name] = float(clf.intercept_[0])
        self.concept_classifiers[concept_name] = clf

        accuracy = clf.score(X, y)
        return {
            "accuracy": accuracy,
            "concept": concept_name,
            "vector_norm": float(np.linalg.norm(w)),
        }

    def train_all(
        self,
        concept_set: ConceptSet,
        features: np.ndarray,
        programs: list[Program],
    ) -> dict[str, dict[str, float]]:
        """
        Train CAVs for all concepts in a concept set.

        Every program is assigned to exactly one class per concept: positive
        if its ``program_id`` is listed in ``concept.programs``, negative
        otherwise. Concepts with fewer than two samples in either class are
        skipped with a warning.

        Args:
            concept_set: Set of discovered concepts.
            features: (n_programs, dim) feature vectors.
            programs: List of programs (same order as features).

        Returns:
            {concept_name: metrics} for each concept that was trained.
        """
        results = {}
        for concept in concept_set.concepts:
            # Build the membership set once: O(1) lookups per program.
            member_ids = set(concept.programs)
            pos_idx: list[int] = []
            neg_idx: list[int] = []
            for i, program in enumerate(programs):
                if program.program_id in member_ids:
                    pos_idx.append(i)
                else:
                    neg_idx.append(i)

            if len(pos_idx) < 2 or len(neg_idx) < 2:
                logger.warning(
                    f"Skipping concept '{concept.name}': insufficient samples "
                    f"(pos={len(pos_idx)}, neg={len(neg_idx)})"
                )
                continue

            results[concept.name] = self.train_concept_vector(
                concept.name, features[pos_idx], features[neg_idx]
            )

        logger.info(f"Trained {len(results)} concept activation vectors")
        return results

    def project(self, features: np.ndarray) -> np.ndarray:
        """
        Project features into concept space.

        Each output column is the dot product with one concept activation
        vector; columns follow the insertion order of ``concept_vectors``.

        Returns (n_samples, n_concepts).

        Raises:
            RuntimeError: If no concept vectors have been trained.
        """
        if not self.concept_vectors:
            raise RuntimeError("No concept vectors trained")
        vectors = np.array(list(self.concept_vectors.values()))  # (n_concepts, dim)
        return features @ vectors.T

    def steer(
        self,
        features: np.ndarray,
        concept_name: str,
        strength: float = 1.0,
    ) -> np.ndarray:
        """
        Steer features toward (or away from) a concept.

        Implements Eq. 3 from GCAV: e' = e + ε · v_concept

        Args:
            features: (n, dim) or (dim,) feature vector(s).
            concept_name: Which concept to steer toward.
            strength: ε — positive = toward, negative = away.

        Returns:
            Steered features.

        Raises:
            ValueError: If ``concept_name`` has no trained vector.
        """
        if concept_name not in self.concept_vectors:
            raise ValueError(f"Unknown concept: {concept_name}")
        return features + strength * self.concept_vectors[concept_name]

    def multi_steer(
        self,
        features: np.ndarray,
        concept_weights: dict[str, float],
    ) -> np.ndarray:
        """
        Steer features along multiple concept dimensions simultaneously.

        Simple additive steering (may cause interference — see MSRS for an
        orthogonal approach). Unlike ``steer``, concept names without a
        trained vector are silently ignored.

        Args:
            features: Feature vector(s).
            concept_weights: {concept_name: strength}.

        Returns:
            A new array; ``features`` itself is not modified.
        """
        steered = features.copy()
        for concept_name, weight in concept_weights.items():
            if concept_name in self.concept_vectors:
                steered = steered + weight * self.concept_vectors[concept_name]
        return steered
class MSRSSteering:
    """
    Multi-Subspace Representation Steering (MSRS, 2508.10599).

    Addresses concept interference by assigning orthogonal subspaces to each
    concept. Key components of the paper:
        1. Shared subspace B_shared: captures common directions across all concepts
        2. Private subspaces B_i: concept-specific orthogonal directions
        3. Adaptive mask m(h): learns to weight subspace dimensions

    Intervention: Φ(h; R, W, b, m) = h + R^T · diag(m(h)) · (Wh + b - Rh)

    This implementation is a simplified, training-free variant. The shared
    subspace is taken from the SVD of the per-concept mean activations. Each
    concept gets a rank-1 private direction: its mean with the shared
    component projected out. The learned intervention parameters (R, W, b)
    and the adaptive mask m(h) are not implemented here.
    """

    def __init__(self, energy_threshold: float = 0.6):
        # Fraction of cumulative singular-value mass kept in the shared subspace.
        self.energy_threshold = energy_threshold
        self.B_shared: Optional[np.ndarray] = None  # (r_s, d), orthonormal rows
        self.B_private: dict[str, np.ndarray] = {}  # name -> (1, d) unit row
        self.S_align: Optional[np.ndarray] = None  # rows: [B_shared; B_private...]

    def fit(
        self,
        concept_features: dict[str, np.ndarray],
    ) -> dict[str, Any]:
        """
        Extract shared and private subspaces for each concept.

        Args:
            concept_features: {concept_name: (n_samples, dim) features}. All
                concepts must share the same ``dim``.

        Returns:
            Summary dict: shared rank, number of concepts, private dims per
            concept and total rows of the alignment matrix.

        Raises:
            ValueError: If ``concept_features`` is empty.
        """
        if not concept_features:
            raise ValueError("concept_features must contain at least one concept")

        # Step 1: Mean activation per concept.
        concept_names = list(concept_features)
        means = {
            name: np.asarray(concept_features[name]).mean(axis=0)
            for name in concept_names
        }

        # Step 2: Combined activation matrix τ_c, one column per concept.
        tau_c = np.column_stack([means[name] for name in concept_names])  # (d, n)

        # Step 3: SVD for the shared subspace.
        # NOTE(review): "energy" here is the cumulative sum of singular values,
        # not of their squares — confirm against the paper.
        U, S, _ = np.linalg.svd(tau_c, full_matrices=False)
        total_energy = S.sum()
        if total_energy > 0:
            cumulative_energy = np.cumsum(S) / total_energy
            r_s = int(np.searchsorted(cumulative_energy, self.energy_threshold) + 1)
        else:
            # All concept means are zero: no meaningful energy split.
            r_s = 1
        r_s = max(1, min(r_s, len(S)))
        # Left singular vectors span the concept means' span in activation
        # space R^d; the right singular vectors live in concept space R^n
        # and would not match the activation dimension.
        self.B_shared = U[:, :r_s].T  # (r_s, d)

        # Step 4: Private direction per concept = normalized residual of its
        # mean after removing the shared component. It is orthogonal to the
        # shared subspace, but not necessarily to other private directions.
        self.B_private = {}
        for name in concept_names:
            mean_act = means[name]
            residual = mean_act - self.B_shared.T @ (self.B_shared @ mean_act)
            norm = np.linalg.norm(residual)
            if norm > 1e-8:
                self.B_private[name] = (residual / norm).reshape(1, -1)

        # Step 5: Alignment matrix S_align = shared rows, then private rows.
        components = [self.B_shared] + [
            self.B_private[name] for name in concept_names if name in self.B_private
        ]
        self.S_align = np.vstack(components)

        return {
            "shared_rank": r_s,
            "n_concepts": len(concept_names),
            "private_dims": {name: B.shape[0] for name, B in self.B_private.items()},
            "total_dims": self.S_align.shape[0],
        }

    def steer(
        self,
        features: np.ndarray,
        concept_weights: dict[str, float],
    ) -> np.ndarray:
        """
        Steer features along each concept's private direction.

        Adds sum_i weight_i * b_i to ``features``, where b_i is the unit
        private direction of concept i. Concepts without a private direction
        (mean fully explained by the shared subspace) and unknown names are
        skipped silently. The shared subspace is not used for steering.

        Args:
            features: (dim,) or (n, dim) feature vector(s).
            concept_weights: {concept_name: strength}.

        Returns:
            A new array; ``features`` itself is not modified.

        Raises:
            RuntimeError: If ``fit`` has not been called.
        """
        if self.B_shared is None:
            raise RuntimeError("MSRS not fitted yet")

        steered = features.copy()
        total_weight = sum(abs(w) for w in concept_weights.values())
        if total_weight > 0:
            direction = np.zeros(features.shape[-1])
            for name, weight in concept_weights.items():
                private = self.B_private.get(name)
                if private is not None:
                    direction += weight * private[0]
            steered = steered + direction
        return steered

    def project(self, features: np.ndarray) -> np.ndarray:
        """
        Project features into the aligned subspace.

        Returns (n, total_dims) coordinates: shared components first, then
        one private component per concept that has one.

        Raises:
            RuntimeError: If ``fit`` has not been called.
        """
        if self.S_align is None:
            raise RuntimeError("MSRS not fitted yet")
        return features @ self.S_align.T
class ConceptEmbeddingSpace:
    """
    Unified embedding space intended to combine the CB-AE and GCAV approaches.

    Provides:
    - Program projection into concept space
    - 2D visualization data (t-SNE / UMAP / PCA)
    - Alignment diagnostics
    - Optional CB-AE / GCAV / MSRS models held for steering

    NOTE: ``project`` currently only implements direct concept scoring; the
    'cbae' and 'gcav' projections need a feature-extraction step that is not
    wired up yet.
    """

    def __init__(
        self,
        concept_set: ConceptSet,
        cbae: Optional[ConceptBottleneckAE] = None,
        gcav: Optional[GCAVEmbedding] = None,
        msrs: Optional[MSRSSteering] = None,
    ):
        self.concept_set = concept_set
        self.cbae = cbae
        self.gcav = gcav
        self.msrs = msrs

    def project(
        self,
        programs: list[Program],
        method: str = "concept_scores",
    ) -> np.ndarray:
        """
        Project programs into concept space.

        Args:
            programs: Programs to project.
            method: 'concept_scores' (direct scoring), 'cbae', 'gcav'. Any
                other value, or 'cbae'/'gcav' without the corresponding
                model, falls back to direct scoring.

        Returns:
            (n_programs, n_concepts) projection.

        Raises:
            NotImplementedError: For 'cbae'/'gcav' when that model is set.
        """
        if method == "concept_scores":
            return self.concept_set.score_matrix(programs)
        elif method == "cbae" and self.cbae is not None:
            raise NotImplementedError("Need features extraction")
        elif method == "gcav" and self.gcav is not None:
            raise NotImplementedError("Need features extraction")
        else:
            return self.concept_set.score_matrix(programs)

    def verify_alignment(
        self,
        programs: list[Program],
        ground_truth_clusters: Optional[list[list[int]]] = None,
    ) -> dict[str, float]:
        """
        Compute diagnostics for the concept-space projection of ``programs``.

        Metrics:
            n_concepts: Number of concept dimensions (always reported).
            avg_concept_correlation: Mean absolute off-diagonal Pearson
                correlation between concept dimensions (only with >1
                concept). Pairs involving a constant dimension have no
                defined correlation and are excluded; NaN if no pair is
                defined.
            mean/min/max_concept_coverage: Fraction of programs scoring > 0
                on each concept.
            effective_dimensionality: exp(entropy) of the per-concept
                variance distribution (only when total variance > 0).

        ``ground_truth_clusters`` is accepted for API compatibility but is
        currently unused; no cluster-discrimination or silhouette metric is
        computed yet.
        """
        projection = np.asarray(self.project(programs))
        n_concepts = projection.shape[1]
        results: dict[str, float] = {"n_concepts": n_concepts}
        if n_concepts == 0:
            # No concept dimensions: nothing to measure (reductions below
            # would fail on an empty axis).
            return results

        # 1. Concept dimension independence.
        if n_concepts > 1:
            # Constant columns make corrcoef divide by zero; silence the
            # warning and drop the resulting NaNs below.
            with np.errstate(invalid="ignore", divide="ignore"):
                corr_matrix = np.corrcoef(projection.T)
            off_diagonal = np.abs(corr_matrix[~np.eye(n_concepts, dtype=bool)])
            defined = off_diagonal[np.isfinite(off_diagonal)]
            results["avg_concept_correlation"] = (
                float(defined.mean()) if defined.size else float("nan")
            )

        # 2. Concept coverage (fraction of programs scored >0 on each concept).
        coverage = (projection > 0).mean(axis=0)
        results["mean_concept_coverage"] = float(coverage.mean())
        results["min_concept_coverage"] = float(coverage.min())
        results["max_concept_coverage"] = float(coverage.max())

        # 3. Effective dimensionality (how many concepts are actually used).
        variance_explained = projection.var(axis=0)
        total_var = variance_explained.sum()
        if total_var > 0:
            normalized_var = variance_explained / total_var
            # 1e-10 avoids log(0) for zero-variance concepts.
            entropy = -np.sum(normalized_var * np.log(normalized_var + 1e-10))
            results["effective_dimensionality"] = float(np.exp(entropy))

        return results

    @staticmethod
    def _assign_colors(programs: list[Program], color_by: str) -> list[int]:
        """Map each program to a dense int label, numbered by first appearance."""
        attribute = {
            "functional_cluster": "functional_signature",
            "model": "model_id",
        }.get(color_by)
        if attribute is None:
            return []
        palette: dict[Any, int] = {}
        # setdefault evaluates len(palette) before inserting, so a new key
        # receives the next unused index.
        return [palette.setdefault(getattr(p, attribute), len(palette)) for p in programs]

    def visualize_2d(
        self,
        programs: list[Program],
        method: str = "tsne",
        color_by: str = "functional_cluster",
    ) -> dict[str, Any]:
        """
        Generate 2D visualization data for the concept space.

        Args:
            programs: Programs to embed (passed to ``project``).
            method: 'tsne', 'umap' (falls back to t-SNE with a warning if the
                ``umap`` package is missing) or 'pca'. t-SNE and PCA need
                scikit-learn.
            color_by: 'functional_cluster' (group by ``functional_signature``)
                or 'model' (group by ``model_id``); any other value yields an
                empty color list.

        Returns:
            Dict with coordinates ('x', 'y'), integer 'colors', 'program_ids',
            'concept_names', the requested 'method' and the 'reducer'
            actually used.

        Raises:
            ValueError: If ``method`` is unknown.
        """
        projection = self.project(programs)

        reducer_name = method
        reducer = None
        if method == "umap":
            try:
                from umap import UMAP

                reducer = UMAP(n_components=2, random_state=42)
            except ImportError:
                logger.warning("umap is not installed; falling back to t-SNE")
                reducer_name = "tsne"

        if reducer_name == "tsne":
            from sklearn.manifold import TSNE

            # t-SNE requires perplexity < n_samples.
            reducer = TSNE(
                n_components=2,
                random_state=42,
                perplexity=min(30, len(programs) - 1),
            )
        elif reducer_name == "pca":
            from sklearn.decomposition import PCA

            reducer = PCA(n_components=2)
        elif reducer is None:
            raise ValueError(f"Unknown method: {method}")

        coords_2d = reducer.fit_transform(projection)

        return {
            "x": coords_2d[:, 0].tolist(),
            "y": coords_2d[:, 1].tolist(),
            "colors": self._assign_colors(programs, color_by),
            "program_ids": [p.program_id for p in programs],
            "concept_names": self.concept_set.names,
            "method": method,
            "reducer": reducer_name,
        }