This project implements the Sparse Crosscoder approach described in "Sparse Crosscoders for Cross-Layer Features and Model Diffing" for analyzing circuits within small language models like Qwen 2.5-3B.
Sparse Crosscoders enable us to understand how language models process information internally by:
- Identifying interpretable features across different layers of the model
- Tracking how these features interact to form computational circuits
- Visualizing attribution graphs showing information flow within the model
Unlike traditional approaches that analyze each layer independently, Crosscoders can discover features that span multiple layers, revealing how complex computations are distributed throughout the model.
The Cross-Layer Transcoder (CLT) is a sparse autoencoder variant that reads from the residual stream at one layer and writes to the MLP outputs of that layer and all subsequent layers. It learns interpretable features that track how information flows through the model.
Key equations:
- Feature activation: `a_l = JumpReLU(W_enc^l x_l)`
- Layer reconstruction: `ŷ_l = ∑_{l'=1}^{l} W_dec^{l'→l} a_{l'}`
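In code, the forward pass corresponds to something like the following PyTorch sketch (the class name, constructor arguments, and fixed threshold are illustrative assumptions, not the actual API in `circuit_tracing.py`):

```python
import torch
import torch.nn as nn

class CrossLayerTranscoder(nn.Module):
    """Reads the residual stream at each layer; reconstructs the MLP output at
    that layer and all later ones. A sketch, not the repository's implementation."""

    def __init__(self, d_model: int, n_features: int, n_layers: int, threshold: float = 0.03):
        super().__init__()
        # One encoder per layer: x_l -> pre-activations of n_features features
        self.W_enc = nn.ModuleList(nn.Linear(d_model, n_features, bias=False)
                                   for _ in range(n_layers))
        # One decoder per (source layer -> target layer) pair with source <= target
        self.W_dec = nn.ModuleDict({
            f"{src}->{tgt}": nn.Linear(n_features, d_model, bias=False)
            for src in range(n_layers) for tgt in range(src, n_layers)
        })
        self.threshold = threshold  # fixed here; learnable in the JumpReLU paper

    def jump_relu(self, x: torch.Tensor) -> torch.Tensor:
        # JumpReLU: identity above the threshold, exactly zero below it
        return x * (x > self.threshold)

    def forward(self, residuals: list[torch.Tensor]):
        # a_l = JumpReLU(W_enc^l x_l)
        acts = [self.jump_relu(enc(x)) for enc, x in zip(self.W_enc, residuals)]
        # ŷ_l = sum over l' <= l of W_dec^{l'->l} a_{l'}
        recons = [sum(self.W_dec[f"{src}->{tgt}"](acts[src]) for src in range(tgt + 1))
                  for tgt in range(len(residuals))]
        return acts, recons
```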
Attribution graphs reveal how features influence each other and ultimately contribute to the model's output:
- Nodes: Individual features
- Edges: Attributions (strength of influence) between features
- Attribution calculation: `A_{s→t} := a_s · w_{s→t}`, where `a_s` is the source feature's activation and `w_{s→t}` is the virtual weight from feature s to feature t
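The virtual weight `w_{s→t}` can be approximated context-independently as the dot product of feature s's decoder direction with feature t's encoder direction. A minimal sketch of turning that into edge attributions (the tensor names and pruning threshold are assumptions for illustration):

```python
import torch

def edge_attributions(acts_src: torch.Tensor,   # (n_features,) source feature activations a_s
                      W_dec_src: torch.Tensor,  # (n_features, d_model) decoder directions, source layer
                      W_enc_tgt: torch.Tensor,  # (n_features, d_model) encoder directions, target layer
                      prune_below: float = 1e-3) -> torch.Tensor:
    # Virtual weight w_{s->t}: how strongly writing feature s's decoder direction
    # into the residual stream excites feature t's encoder direction
    virtual_w = W_dec_src @ W_enc_tgt.T               # (n_src, n_tgt)
    # Attribution A_{s->t} := a_s * w_{s->t}
    attributions = acts_src.unsqueeze(1) * virtual_w  # broadcast over targets
    attributions[attributions.abs() < prune_below] = 0.0  # keep the graph sparse
    return attributions
```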
```bash
# Clone the repository
git clone https://github.com/username/circuit-tracing.git
cd circuit-tracing

# Install dependencies
pip install torch transformers matplotlib networkx numpy tqdm
```
```bash
# Train a CLT
python example_usage.py --mode train \
    --model_name Qwen/Qwen2.5-3B \
    --clt_path ./models/clt_qwen25_3b \
    --train_samples 1000 \
    --data_path ./data/training_text.txt

# Quick demo: train on the bundled 25-sample dataset
python circuit_tracing_example.py --mode train \
    --model_name Qwen/Qwen2.5-3B \
    --clt_path models/clt_qwen25_3b \
    --train_samples 25 \
    --data_path training/demo/demo_training_data_25.txt
```
```bash
# Analyze a prompt with a trained CLT
python example_usage.py --mode analyze \
    --model_name Qwen/Qwen2.5-3B \
    --clt_path ./models/clt_qwen25_3b \
    --prompt "The capital of France is"

# The same analysis via circuit_tracing_example.py
python circuit_tracing_example.py --mode analyze \
    --model_name Qwen/Qwen2.5-3B \
    --clt_path models/clt_qwen25_3b \
    --prompt "The capital of France is"

# Debug run with a timeout and a reduced feature count
python debug_circuit_tracing.py \
    --model_name Qwen/Qwen2.5-3B \
    --clt_path models/clt_qwen25_3b \
    --prompt "The capital of France is" \
    --timeout 600 \
    --reduced_features 5
```
```bash
python feature_analysis.py \
    --model_name Qwen/Qwen2.5-3B \
    --clt_path ./models/clt_qwen25_3b \
    --prompts_file ./data/analysis_prompts.txt \
    --output_dir ./results/feature_analysis \
    --feature_idx 42  # Optional: analyze a specific feature
```
- `circuit_tracing.py`: Core implementation of the Cross-Layer Transcoder and attribution analysis
- `example_usage.py`: Example script demonstrating how to train a CLT and analyze prompts
- `feature_analysis.py`: Tools for analyzing and interpreting specific features
1. Training the Cross-Layer Transcoder:
   - Collect activations from the model's residual stream and MLP outputs
   - Train the CLT to reconstruct the MLP outputs from sparse feature activations (see the training-step sketch after this list)
   - The learned features represent interpretable concepts and computations
2. Analyzing Prompts:
   - Run the model on a prompt and record its activations
   - Use the CLT to extract feature activations
   - Compute attributions between features
   - Generate and visualize the attribution graph
3. Interpreting Features:
   - Identify top-activating prompts for each feature
   - Analyze how features influence the model's predictions
   - Group features into "supernodes" representing higher-level concepts
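A training step for item 1 might look roughly like the following, assuming a `CrossLayerTranscoder` module like the sketch above and pre-collected activations (the L1 coefficient and optimizer settings are placeholders, not the repository's actual defaults):

```python
import torch
import torch.nn.functional as F

def train_step(clt, optimizer, residuals, mlp_outputs, l1_coeff=1e-3):
    """One optimization step.

    residuals:   list of (batch, d_model) residual-stream activations, one per layer
    mlp_outputs: list of (batch, d_model) MLP outputs, one per layer (reconstruction targets)
    """
    acts, recons = clt(residuals)
    # Reconstruction loss: match each layer's MLP output
    recon_loss = sum(F.mse_loss(r, t) for r, t in zip(recons, mlp_outputs))
    # Sparsity penalty: L1 on feature activations encourages sparse, interpretable features
    sparsity = sum(a.abs().mean() for a in acts)
    loss = recon_loss + l1_coeff * sparsity
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```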
- JumpReLU Activation: Improves feature sparsity and interpretability
- Layer Normalization: Ensures stable training across different layers
- Sparsity Penalty: Encourages the development of sparse, interpretable features
- Virtual Weights: Capture context-independent interactions between features
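The JumpReLU activation above uses a threshold; Rajamanoharan et al. (2024), cited below, make that threshold learnable via a straight-through pseudo-gradient. A minimal sketch of the idea, assuming a scalar threshold and the paper's rectangle-window kernel (the bandwidth `eps` is a placeholder, and none of this is this repository's code):

```python
import torch

class JumpReLUFn(torch.autograd.Function):
    """JumpReLU(x) = x * 1[x > theta], with a pseudo-gradient for theta.
    Sketch of the straight-through estimator from Rajamanoharan et al. (2024)."""

    @staticmethod
    def forward(ctx, x, theta, eps=1e-2):
        ctx.save_for_backward(x, theta)
        ctx.eps = eps
        return x * (x > theta)

    @staticmethod
    def backward(ctx, grad_out):
        x, theta = ctx.saved_tensors
        eps = ctx.eps
        grad_x = grad_out * (x > theta)  # exact gradient w.r.t. x above the threshold
        # Rectangle-window kernel around the threshold gives theta a nonzero gradient
        window = ((x - theta).abs() < eps / 2).float() / eps
        grad_theta = (-theta * window * grad_out).sum()
        return grad_x, grad_theta, None

# Usage: y = JumpReLUFn.apply(x, theta), with theta a learnable scalar tensor
```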
- Number of Features: Adjust based on model size and complexity
- Sparsity Coefficient: Control the trade-off between reconstruction accuracy and feature interpretability
- Layer Selection: Focus on specific layers of interest
- Attribution Thresholds: Filter attributions to control graph complexity
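As a rough illustration of how these options fit together (the field names and values below are hypothetical, not actual flags or defaults of this project):

```python
# Hypothetical CLT configuration; every field name here is illustrative
clt_config = {
    "n_features": 16384,            # scale with model size; fewer features train faster
    "sparsity_coeff": 1e-3,         # higher => sparser features, lower reconstruction accuracy
    "layers": list(range(8, 24)),   # restrict training/analysis to layers of interest
    "attribution_threshold": 0.01,  # drop weaker edges to keep graphs readable
}
```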
When analyzing prompts like "The capital of France is", the attribution graph might show:
- Features representing "capital" and "France" concepts
- Their influence on features representing "Paris"
- The flow of information to the final logit prediction
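A graph like this can be rendered with `networkx` and `matplotlib` (both installed above); the `edges` format and node labels below are illustrative:

```python
import networkx as nx
import matplotlib.pyplot as plt

def plot_attribution_graph(edges, out_path="attribution_graph.png"):
    """edges: iterable of (source_feature, target_feature, attribution) tuples."""
    g = nx.DiGraph()
    for src, tgt, weight in edges:
        g.add_edge(src, tgt, weight=weight)
    pos = nx.spring_layout(g, seed=0)
    widths = [2.0 * abs(g[u][v]["weight"]) for u, v in g.edges]  # edge width ~ |attribution|
    nx.draw_networkx(g, pos, node_size=900, font_size=8, width=widths, arrows=True)
    plt.savefig(out_path, bbox_inches="tight")
    plt.close()

# e.g., a toy graph for "The capital of France is"
plot_attribution_graph([
    ("capital", "Paris", 0.8),
    ("France", "Paris", 0.9),
    ("Paris", "logit: ' Paris'", 1.2),
])
```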
- Computation Intensive: Training a CLT requires significant computational resources
- Approximations: The attribution calculation involves approximations that may not perfectly represent the model's computation
- Interpretability Challenges: Not all features may have clear interpretations
- Limited to MLPs: Current implementation focuses on MLP computations, not attention mechanisms
- Attention Integration: Extend the approach to attention mechanisms, not just MLP computations
- Feature Clustering: Develop automated methods for grouping related features
- Circuit Extraction: Automatically identify common circuit motifs
- Safety Analysis: Apply to studying safety mechanisms in aligned models
- "Sparse Crosscoders for Cross-Layer Features and Model Diffing", Anthropic, October 2024
- "Towards Monosemanticity: Decomposing Language Models With Dictionary Learning", Bricken et al., 2023
- "Jumping ahead: Improving reconstruction fidelity with jumprelu sparse autoencoders", Rajamanoharan et al., 2024
This project is licensed under the MIT License - see the LICENSE file for details.