diff --git a/.gitmodules b/.gitmodules deleted file mode 100644 index 9ca4ddb..0000000 --- a/.gitmodules +++ /dev/null @@ -1,4 +0,0 @@ -[submodule "lib/olmo-eval"] - path = lib/olmo-eval - url = https://github.com/allenai/OLMo-Eval.git - ignore = untracked \ No newline at end of file diff --git a/README.md b/README.md index a5ccb3b..6121e69 100644 --- a/README.md +++ b/README.md @@ -149,7 +149,7 @@ If you use Pico in your research, please cite: ```bibtex @software{pico2024, - author = {Martinez, Richard Diehl}, + author = {Diehl Martinez, Richard}, title = {Pico: Framework for Training Tiny Language Models}, year = {2024}, } diff --git a/config.py b/config.py index 8972b58..3ac813d 100644 --- a/config.py +++ b/config.py @@ -160,9 +160,9 @@ class TrainingConfig: @dataclass -class _PalomaEvaluationConfig: - limit_eval_examples: Optional[int] = 1 +class PalomaEvaluationConfig: max_length: int = MAX_SEQ_LEN + batch_size: int = 16 @dataclass @@ -182,4 +182,4 @@ class EvaluationConfig: # NOTE: Add other evaluation configs here # Each evaluation metric should have its own config - paloma: _PalomaEvaluationConfig = field(default_factory=_PalomaEvaluationConfig) + paloma: PalomaEvaluationConfig = field(default_factory=PalomaEvaluationConfig) diff --git a/lib/olmo-eval b/lib/olmo-eval deleted file mode 160000 index 51c5ba5..0000000 --- a/lib/olmo-eval +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 51c5ba579e75ef4ce7e9b29936eaa72c1a0e99eb diff --git a/model.py b/model.py index 65ddb04..c8e2f55 100644 --- a/model.py +++ b/model.py @@ -44,6 +44,7 @@ from transformers import PretrainedConfig, PreTrainedModel +from transformers.modeling_outputs import CausalLMOutputWithPast, CausalLMOutput ######################################################## # @@ -189,6 +190,8 @@ def forward( # otherwise, we need to move it to the correct device if self.fabric is not None: freqs_cis = self.fabric.to_device(freqs_cis) + else: + freqs_cis = freqs_cis.to(queries.device) queries_rotated = torch.view_as_real(queries_ * freqs_cis).flatten(3) keys_rotated = torch.view_as_real(keys_ * freqs_cis).flatten(3) @@ -577,8 +580,17 @@ def forward( past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, use_cache: bool = False, **kwargs, - ) -> Tuple[torch.Tensor, Optional[Tuple[Tuple[torch.Tensor]]]]: - return self.pico(input_ids, past_key_values, use_cache) + ) -> Union[CausalLMOutput, CausalLMOutputWithPast]: + logits, past_key_values = self.pico(input_ids, past_key_values, use_cache) + if use_cache: + return CausalLMOutputWithPast( + logits=logits, + past_key_values=past_key_values, + ) + else: + return CausalLMOutput( + logits=logits, + ) # Register for auto classes diff --git a/pyproject.toml b/pyproject.toml index 25aae49..54a90ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,17 +9,15 @@ readme = "README.md" [tool.poetry.dependencies] python = "^3.10" -ray = "^2.35.0" lightning = "^2.4.0" -omegaconf = "^2.3.0" click = "^8.1.7" wandb = "^0.18.1" huggingface-hub = {extras = ["cli"], version = "^0.25.1"} -jsonnet = "^0.20.0" -virtualenv = "^20.27.1" datasets = "^3.0.1" transformers = "^4.45.2" pre-commit = "^4.0.1" +torch = "^2.5.1" +evaluate = "^0.4.3" [tool.poetry.group.dev.dependencies] ipykernel = "^6.29.5" diff --git a/setup.sh b/setup.sh index 7becc64..4f9093b 100755 --- a/setup.sh +++ b/setup.sh @@ -28,10 +28,11 @@ print_warning() { echo -e "${YELLOW}⚠ $1${NC}" } -# Check if git-lfs is installed +# --- GIT LFS SETUP --- # print_section "Git LFS Setup" if ! 
command -v git-lfs &> /dev/null; then print_warning "git-lfs is not installed. Some model checkpointing functionality may not work correctly." + ERRORS_FOUND=$((ERRORS_FOUND + 1)) # Check the operating system if [[ "$OSTYPE" == "darwin"* ]]; then @@ -60,7 +61,7 @@ else print_success "git-lfs installed and initialized" fi -# Check CUDA version +# --- CUDA VERSION CHECK --- # print_section "CUDA Version Check" if command -v nvidia-smi &> /dev/null; then CUDA_VERSION=$(nvidia-smi | sed -n 's/.*CUDA Version: \([0-9.]*\).*/\1/p') @@ -77,7 +78,7 @@ if command -v nvidia-smi &> /dev/null; then ERRORS_FOUND=$((ERRORS_FOUND + 1)) print_warning "CUDA version ${MAJOR_VERSION}.${MINOR_VERSION} detected." echo -e "${YELLOW} Some multi-node communication GPU features may not work properly.${NC}" - echo -e "${YELLOW} CUDA version 12.1 or newer is required.${NC}" + echo -e "${YELLOW} CUDA version 12.1 or newer is recommended.${NC}" else print_success "CUDA version ${MAJOR_VERSION}.${MINOR_VERSION} detected" fi @@ -88,11 +89,6 @@ else echo -e "${YELLOW} Ensure that NVIDIA drivers and CUDA version at 12.1 or newer are installed for GPU support.${NC}" fi -# Initialize and update git submodules -print_section "Git Submodules" -echo "Initializing git submodules..." -git submodule update --init --recursive -print_success "Git submodules initialized" # ---- ENVIRONMENT VARIABLES ---- # print_section "Environment Variables" @@ -105,6 +101,7 @@ else echo -e "${YELLOW} Example .env contents:${NC}" echo " export HF_TOKEN=your_huggingface_token" echo " export WANDB_API_KEY=your_wandb_key" + ERRORS_FOUND=$((ERRORS_FOUND + 1)) fi # ---- POETRY SETUP ---- # @@ -132,15 +129,6 @@ fi # ---- PRE-COMMIT SETUP ---- # print_section "Pre-commit Setup" -# First check if pre-commit is installed in the poetry environment -if ! poetry run pre-commit --version &> /dev/null; then - echo "Installing pre-commit in poetry environment..." - poetry add pre-commit --group dev - print_success "pre-commit installed successfully" -else - print_success "pre-commit already installed" -fi - # Install pre-commit hooks echo "Installing pre-commit hooks..." poetry run pre-commit install @@ -151,58 +139,7 @@ echo "Running pre-commit hooks on all files..." poetry run pre-commit run --all-files print_success "Pre-commit initial run complete" -# ---- EVALUATION SETUP ---- # -print_section "Evaluation (Paloma) Setup" - -# Add flag check for skipping evaluation -if [ "$1" = "--skip-eval" ]; then - print_warning "Skipping evaluation setup as requested" -else - if [ ! -d "lib/paloma" ]; then - if [ ! -z "$HF_TOKEN" ]; then - echo "Setting up HuggingFace authentication..." - echo $HF_TOKEN | poetry run huggingface-cli login --token $HF_TOKEN - - echo "Cloning Paloma evaluation dataset..." - git clone https://oauth2:${HF_TOKEN}@huggingface.co/datasets/allenai/paloma lib/paloma - - if [ $? -eq 0 ]; then - print_success "Paloma dataset cloned successfully" - else - ERRORS_FOUND=$((ERRORS_FOUND + 1)) - print_warning "Failed to clone Paloma dataset" - echo -e "${YELLOW} Please verify your HuggingFace token has correct permissions${NC}" - echo -e "${YELLOW} Make sure you have been granted access to allenai/paloma dataset${NC}" - rm -rf lib/paloma - fi - else - print_warning "Skipping Paloma dataset clone. HuggingFace credentials not found." 
- echo -e "${YELLOW} You need to request access to the Paloma dataset on HuggingFace:${NC}" - echo -e " ${BLUE}https://huggingface.co/datasets/allenai/paloma${NC}" - echo -e "${YELLOW} Visit the dataset page and click 'Access Request' to request permission.${NC}" - rm -rf lib/paloma - fi - else - print_success "Paloma dataset already exists, skipping clone" - fi - - # Create environment for running evaluation inside of lib/olmo_eval - if [ ! -d "lib/olmo-eval/env" ]; then - print_section "OLMo Eval Setup" - poetry run bash -c ' - cd lib/olmo-eval - echo "Creating virtual environment..." - virtualenv env - source env/bin/activate - pip install --python-version 3.10 -e . - deactivate - cd ../../ - echo "OLMo eval environment setup complete" - ' - else - print_success "OLMo eval environment already exists, skipping setup" - fi -fi +# --- Final Status Message --- # # Final status message print_section "Setup Status" diff --git a/utils/evaluation.py b/utils/evaluation.py index 41acc9c..5cee2b1 100644 --- a/utils/evaluation.py +++ b/utils/evaluation.py @@ -6,7 +6,9 @@ evaluation workflow. NOTE: out of the box we only support Paloma, but the structure is designed to be flexible and -you are meant to add whatever metrics you want. +you are meant to add whatever metrics you want. One of the main reasons we store out +the model in the HuggingFace format is so that its easy to use third-party evaluation +libraries/frameworks. Main Workflow: 1. Setup evaluation configuration @@ -15,16 +17,11 @@ 4. Clean up temporary files and workspaces """ -import os -import _jsonnet -import tempfile -import subprocess -import shutil -from pathlib import Path -import json -import gzip +from datasets import load_dataset +import evaluate -from config import EvaluationConfig +import os +from config import EvaluationConfig, PalomaEvaluationConfig from . import RUNS_DIR, CHECKPOINT_DIR @@ -39,196 +36,67 @@ Paloma is a comprehensive evaluation benchmark for large language models (LLMs) that focuses on measuring perplexity across diverse text domains. -To evaluate on Paloma, we use the olmo-eval library, which provides a unified interface for -evaluating models on a variety of benchmarks. +To evaluate on Paloma, we use the huggingface evaluation framework. For more details, see: https://huggingface.co/datasets/allenai/paloma """ -TEMP_EVAL_RESULTS_DIR = "_temp_paloma_results" -PPL_METRICS_FILE = "ppl_metrics.jsonl.gz" -EVAL_DATA_PATH = "lib/paloma" -EVAL_LIB_DIR = "lib/olmo-eval" - -# NOTE: the jsonnet template is what is used by the olmo-eval library to run the evaluation. -jsonnet_template = """ - local utils = import 'lib/olmo-eval/configs/utils.libsonnet'; - local ppl_suite = import 'lib/olmo-eval/configs/task_sets/paloma_hf_release_val.libsonnet'; - - local gsheet = null; - local output_dir = std.extVar('output_dir'); - - local model_path = std.extVar('model_path'); - local max_length = std.parseInt(std.extVar('max_length')); - local limit = std.parseInt(std.extVar('limit')); - - local model = { - model_path: model_path, - revision: null, - gpus_needed: 1, - trust_remote_code: true, - prediction_kwargs: { - model_max_length: max_length, - limit: limit, - } - }; - - local task_sets = [ - ppl_suite.task_set - ]; - - { - steps: utils.create_fine_grained_pipeline([model], task_sets, gsheet, output_dir) - } -""" - - -def setup_paloma_config(model_path: str, evaluation_config: EvaluationConfig) -> str: - """Create Paloma config from evaluation configuration. 
- - This function generates a Jsonnet configuration file for Paloma evaluation by: - 1. Setting up the output directory structure - 2. Configuring model-specific parameters (max length, example limits) - 3. Applying the configuration template for the Paloma evaluation suite +PALOMA_SUB_CONFIGS = [ + "4chan_meta_sep", + "c4_100_domains", + "c4_en", + "dolma_100_programing_languages", + "dolma_100_subreddits", + "dolma-v1_5", + "falcon-refinedweb", + "gab", + "m2d2_s2orc_unsplit", + "m2d2_wikipedia_unsplit", + "manosphere_meta_sep", + "mc4", + "ptb", + "redpajama", + "twitterAAE_HELM_fixed", + "wikitext_103", +] + + +def run_paloma_evaluation( + model_path: str, paloma_config: PalomaEvaluationConfig +) -> float: + """Run perplexity evaluation on the Paloma evaluation dataset. We use the HuggingFace + evaluate library to load the data and compute the perplexity metric. Args: model_path (str): Path to the model checkpoint to be evaluated - evaluation_config (EvaluationConfig): Configuration object containing: - - run_name (str): Name of the evaluation run - - paloma.max_length (int): Maximum sequence length for evaluation - - paloma.limit_eval_examples (Optional[int]): Number of examples to evaluate - (None for full evaluation) - - Returns: - str: Path to the generated temporary Jsonnet configuration file - - Example: - config_path = setup_paloma_config( - model_path="/checkpoints/model-1000", - evaluation_config=EvaluationConfig( - run_name="experiment_1", - paloma=PalomaConfig( - max_length=2048, - limit_eval_examples=100 - ) - ) - ) - - Note: - The generated config uses the Paloma evaluation suite's standard template - with custom parameters for the specific model evaluation run. + paloma_config (PalomaEvaluationConfig): Configuration for Paloma evaluation """ - # Convert evaluation config to external vars - output_dir = ( - f"{os.getcwd()}/{RUNS_DIR}/{evaluation_config.run_name}/{TEMP_EVAL_RESULTS_DIR}" - ) - - # create output dir - os.makedirs(output_dir, exist_ok=True) - - ext_vars = { - "output_dir": output_dir, - "model_path": model_path, - "max_length": str(evaluation_config.paloma.max_length), - "limit": "null" - if evaluation_config.paloma.limit_eval_examples is None - else str(evaluation_config.paloma.limit_eval_examples), - } - - # Evaluate template with overrides - json_str = _jsonnet.evaluate_snippet("config", jsonnet_template, ext_vars=ext_vars) - - # Write to temporary file - temp_config = tempfile.NamedTemporaryFile(mode="w", suffix=".jsonnet", delete=False) - temp_config.write(json_str) - temp_config.close() - - return temp_config.name - - -def run_paloma_evaluation(paloma_config_path: str) -> None: - """Run Paloma evaluation using the Tango framework. - - This function executes the evaluation process for the Paloma benchmark by: - 1. Activating the virtual environment for the olmo-eval library - 2. Running the Tango command to perform evaluation based on the provided config - 3. Managing temporary workspaces and cleaning up after execution - - Args: - paloma_config_path (str): Path to the Jsonnet configuration file for Paloma evaluation - - Note: - Ensure that the environment is correctly set up with all dependencies - before running this function. The function uses bash to execute commands.
- """ - olmo_eval_dir = Path(EVAL_LIB_DIR) - venv_activate = "env/bin/activate" - tmp_workspace_name = "pico-tmp-eval-ws" - - # Construct the command with source activation - cmd = f"source {venv_activate} && tango --settings tango.yml run {paloma_config_path} --workspace {tmp_workspace_name}" - - try: - subprocess.run( - cmd, - cwd=olmo_eval_dir, - shell=True, # Required for source command - executable="/bin/bash", # Ensure bash is used - check=True, - text=True, - capture_output=False, - env=os.environ.copy(), + # Load in custom perplexity metric (this is just a fork of the normal perplexity metric + # that makes it possible to pass `trust_remote_code=True` to the `compute` method) + perplexity = evaluate.load("pico-lm/perplexity") + + perplexity_results = {} + perplexity_counts = {} + + for sub_config in PALOMA_SUB_CONFIGS: + dataset = load_dataset("allenai/paloma", sub_config, split="val")["text"] + perplexity_result = perplexity.compute( + model_id=model_path, + predictions=dataset, + add_start_token=False, + max_length=paloma_config.max_length, + batch_size=paloma_config.batch_size, + trust_remote_code=True, ) - except subprocess.CalledProcessError as e: - raise RuntimeError(f"Model evaluation failed: {e}") + perplexity_results[sub_config] = perplexity_result["mean_perplexity"] + perplexity_counts[sub_config] = len(dataset) - # Delete workspace cache - shutil.rmtree(f"{os.getcwd()}/{EVAL_LIB_DIR}/{tmp_workspace_name}") - - -def process_tango_output(evaluation_config: EvaluationConfig) -> None: - """Process and aggregate the results from Paloma evaluation. - - This function handles the post-processing of Tango evaluation outputs by: - 1. Loading the compressed metrics file from the evaluation directory - 2. Processing the JSONL format containing per-example perplexity scores - 3. Computing the average perplexity across all evaluated examples - 4. Cleaning up temporary evaluation files - - Args: - evaluation_config (EvaluationConfig): Configuration object containing: - - run_name (str): Name of the evaluation run used for directory paths - - Returns: - float: Average perplexity score across all evaluated examples - - File Structure: - The function expects results in: - {RUNS_DIR}/{run_name}/{TEMP_EVAL_RESULTS_DIR}/ppl_metrics.jsonl.gz - - Note: - This function automatically cleans up temporary evaluation files - after processing to maintain disk space efficiency. - """ - output_dir = ( - f"{os.getcwd()}/{RUNS_DIR}/{evaluation_config.run_name}/{TEMP_EVAL_RESULTS_DIR}" - ) - # load in ppl metrics - ppl_metrics_path = os.path.join(output_dir, PPL_METRICS_FILE) - - with gzip.open(ppl_metrics_path, "rt") as f: # 'rt' mode for text reading - # For a JSONL file, read line by line - ppl_metrics = [json.loads(line) for line in f] - - # average together the ppl_primary metrics - ppl_primary_metrics = [metric["ppl_primary"] for metric in ppl_metrics] - avg_ppl_primary = sum(ppl_primary_metrics) / len(ppl_primary_metrics) - - # delete ppl metrics file -- it would be too messy to keep around - shutil.rmtree(output_dir) - - return avg_ppl_primary + # return micro average perplexity + return sum( + perplexity_results[sub_config] * perplexity_counts[sub_config] + for sub_config in PALOMA_SUB_CONFIGS + ) / sum(perplexity_counts.values()) ######################################################## @@ -243,7 +111,7 @@ def run_evaluation(evaluation_config: EvaluationConfig) -> None: This function orchestrates the complete evaluation pipeline by: 1. 
Resolving the model checkpoint path (either specified or latest) - 2. Creating necessary directories and environment setup + 2. Running any setup steps required by the evaluation metric 3. Executing each requested evaluation metric 4. Aggregating results across all metrics @@ -270,7 +138,7 @@ def run_evaluation(evaluation_config: EvaluationConfig) -> None: EvaluationConfig( run_name="experiment_1", evaluation_metrics=["paloma"], - paloma=PalomaConfig(max_length=2048) + paloma=PalomaEvaluationConfig(max_length=2048, batch_size=16) ) ) @@ -291,16 +159,10 @@ def run_evaluation(evaluation_config: EvaluationConfig) -> None: for metric in evaluation_config.evaluation_metrics: if metric == "paloma": - paloma_config_path = setup_paloma_config(model_path, evaluation_config) - - os.environ["EVAL_DATA_PATH"] = os.path.join( - os.getcwd(), "lib" - ) # No need for "paloma" - run_paloma_evaluation(paloma_config_path) - metric_result = process_tango_output(evaluation_config) + paloma_result = run_paloma_evaluation(model_path, evaluation_config.paloma) else: raise ValueError(f"Metric {metric} not supported") - evaluation_results[metric] = metric_result + evaluation_results[metric] = paloma_result return evaluation_results
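
Usage note: the sketch below shows how the new HuggingFace-based evaluation path introduced in this patch might be driven end to end. It is a minimal illustration assuming the `config` / `utils.evaluation` module layout shown in the diff; the run name and checkpoint path are hypothetical placeholders, not values from the patch.

```python
# Minimal sketch of driving the new Paloma evaluation path (assumptions noted below).
from config import EvaluationConfig, PalomaEvaluationConfig
from utils.evaluation import run_evaluation, run_paloma_evaluation

# Option 1: evaluate a specific HuggingFace-format checkpoint directly on Paloma.
paloma_config = PalomaEvaluationConfig(max_length=2048, batch_size=16)
avg_ppl = run_paloma_evaluation(
    model_path="runs/experiment_1/checkpoints/latest",  # hypothetical checkpoint path
    paloma_config=paloma_config,
)
print(f"Example-weighted average Paloma perplexity: {avg_ppl:.2f}")

# Option 2: go through the top-level entry point, which resolves the checkpoint
# from the run name and returns a dict keyed by metric name.
results = run_evaluation(
    EvaluationConfig(
        run_name="experiment_1",          # hypothetical run name
        evaluation_metrics=["paloma"],
        paloma=paloma_config,
    )
)
print(results["paloma"])
```

Note that `run_paloma_evaluation` collapses the per-sub-config mean perplexities into a single example-count-weighted average; the per-domain numbers are computed internally but not returned, so the function would need a small extension if per-domain logging is desired.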