Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ jobs:
- name: CmdStan installation cacheing
id: cache-cmdstan
if: ${{ !startswith(needs.get-cmdstan-version.outputs.version, 'git:') }}
uses: actions/cache@v4
uses: actions/cache@v5
with:
path: ~/.cmdstan
key: ${{ runner.os }}-cmdstan-${{ needs.get-cmdstan-version.outputs.version }}-${{ hashFiles('**/install_cmdstan.py') }}
Expand Down
11 changes: 9 additions & 2 deletions cmdstanpy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1040,7 +1040,10 @@ def generate_quantities(
),
):
fit_object = previous_fit
fit_csv_files = previous_fit.runset.csv_files
if isinstance(previous_fit, CmdStanPathfinder):
fit_csv_files = [previous_fit.csv_file]
else:
fit_csv_files = previous_fit.runset.csv_files
elif isinstance(previous_fit, list):
if len(previous_fit) < 1:
raise ValueError(
Expand Down Expand Up @@ -1553,7 +1556,11 @@ def pathfinder(
' '.join(runset.cmd(0)), runset.get_err_msgs()
)
raise RuntimeError(msg)
return CmdStanPathfinder(runset)
return CmdStanPathfinder.from_files(
csv_file=runset.csv_files[0],
config_file=runset.config_files[0],
stdout_file=runset.stdout_files[0],
)

def log_prob(
self,
Expand Down
29 changes: 14 additions & 15 deletions cmdstanpy/stanfit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
CmdStanArgs,
LaplaceArgs,
OptimizeArgs,
PathfinderArgs,
SamplerArgs,
VariationalArgs,
)
Expand Down Expand Up @@ -256,21 +255,21 @@ def from_csv(
)
return CmdStanLaplace(runset, mode=mode)
elif config_dict['method'] == 'pathfinder':
pathfinder_args = PathfinderArgs(
num_draws=config_dict['num_draws'], # type: ignore
num_paths=config_dict['num_paths'], # type: ignore
)
cmdstan_args = CmdStanArgs(
model_name=model,
model_exe=model,
chain_ids=None,
method_args=pathfinder_args,
if len(csvfiles) != 1:
raise ValueError(
'Expecting a single Pathfinder Stan CSV file, '
f'found {len(csvfiles)}'
)
csv_file = csvfiles[0]
config_file = os.path.splitext(csv_file)[0] + '_config.json'
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume this is temporary but it's worth putting a comment saying so, since this is kind of a nasty assumption

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah definitely temporary. I was planning on doing away with this function entirely once the full changes are in.

if not os.path.exists(config_file):
raise ValueError(
'Pathfinder config file not found at expected path: '
f'{config_file}'
)
return CmdStanPathfinder.from_files(
csv_file=csv_file, config_file=config_file
)
runset = RunSet(args=cmdstan_args)
runset._csv_files = csvfiles
for i in range(len(runset._retcodes)):
runset._set_retcode(i, 0)
return CmdStanPathfinder(runset)
else:
get_logger().warning(
'Unable to process CSV output files from method %s.',
Expand Down
31 changes: 11 additions & 20 deletions cmdstanpy/stanfit/gq.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,8 @@
from __future__ import annotations

from collections import Counter
from typing import (
Any,
Generic,
Hashable,
MutableMapping,
NoReturn,
TypeVar,
overload,
)
from collections.abc import Hashable
from typing import Any, Generic, MutableMapping, NoReturn, TypeVar, overload

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -239,7 +232,7 @@ def draws(
drop_cols: list[int] = []
for dup in dups:
drop_cols.extend(
self.previous_fit._metadata.stan_vars[dup].columns()
self.previous_fit.metadata.stan_vars[dup].columns()
)

start_idx, _ = self._draws_start(inc_warmup)
Expand Down Expand Up @@ -333,10 +326,8 @@ def draws_pd(
gq_cols.extend(
self.column_names[info.start_idx : info.end_idx]
)
elif (
inc_sample and var in self.previous_fit._metadata.stan_vars
):
info = self.previous_fit._metadata.stan_vars[var]
elif inc_sample and var in self.previous_fit.metadata.stan_vars:
info = self.previous_fit.metadata.stan_vars[var]
mcmc_vars.extend(
self.previous_fit.column_names[
info.start_idx : info.end_idx
Expand Down Expand Up @@ -472,7 +463,7 @@ def draws_xr(
for var in vars_list:
if var not in self._metadata.stan_vars:
if inc_sample and (
var in self.previous_fit._metadata.stan_vars
var in self.previous_fit.metadata.stan_vars
):
mcmc_vars_list.append(var)
dup_vars.append(var)
Expand All @@ -481,7 +472,7 @@ def draws_xr(
else:
vars_list = list(self._metadata.stan_vars.keys())
if inc_sample:
for var in self.previous_fit._metadata.stan_vars.keys():
for var in self.previous_fit.metadata.stan_vars.keys():
if var not in vars_list and var not in mcmc_vars_list:
mcmc_vars_list.append(var)
for var in dup_vars:
Expand All @@ -490,7 +481,7 @@ def draws_xr(
self._assemble_generated_quantities()

num_draws = self.previous_fit.num_draws_sampling
sample_config = self.previous_fit._metadata.cmdstan_config
sample_config = self.previous_fit.metadata.cmdstan_config
attrs: MutableMapping[Hashable, Any] = {
"stan_version": f"{sample_config['stan_version_major']}."
f"{sample_config['stan_version_minor']}."
Expand Down Expand Up @@ -518,7 +509,7 @@ def draws_xr(
for var in mcmc_vars_list:
build_xarray_data(
data,
self.previous_fit._metadata.stan_vars[var],
self.previous_fit.metadata.stan_vars[var],
self.previous_fit.draws(inc_warmup=inc_warmup),
)

Expand Down Expand Up @@ -570,7 +561,7 @@ def stan_variable(self, var: str, **kwargs: bool) -> np.ndarray:
CmdStanVB.stan_variable
CmdStanLaplace.stan_variable
"""
model_var_names = self.previous_fit._metadata.stan_vars.keys()
model_var_names = self.previous_fit.metadata.stan_vars.keys()
gq_var_names = self._metadata.stan_vars.keys()
if not (var in model_var_names or var in gq_var_names):
raise ValueError(
Expand Down Expand Up @@ -611,7 +602,7 @@ def stan_variables(self, **kwargs: bool) -> dict[str, np.ndarray]:
CmdStanLaplace.stan_variables
"""
result = {}
sample_var_names = self.previous_fit._metadata.stan_vars.keys()
sample_var_names = self.previous_fit.metadata.stan_vars.keys()
gq_var_names = self._metadata.stan_vars.keys()
for name in gq_var_names:
result[name] = self.stan_variable(name, **kwargs)
Expand Down
165 changes: 163 additions & 2 deletions cmdstanpy/stanfit/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,19 @@
from __future__ import annotations

import copy
import json
import math
import os
from typing import Any, Iterator, Literal
from typing import Annotated, Any, Iterator, Literal

import stanio
from pydantic import BaseModel, field_validator, model_validator
from pydantic import (
BaseModel,
ConfigDict,
Discriminator,
field_validator,
model_validator,
)

from cmdstanpy.utils import stancsv

Expand Down Expand Up @@ -125,3 +132,157 @@ def validate_inv_metric_shape(self) -> MetricInfo:
raise ValueError("Dense inv_metric must be square")

return self


class SampleConfig(BaseModel):
    """Config recorded for a CmdStan ``sample`` (MCMC) run.

    Unrecognized keys from the config JSON are retained via
    ``extra="allow"``.
    """

    model_config = ConfigDict(extra="allow")

    # Discriminator tag used by StanConfig.method_config.
    method: Literal["sample"] = "sample"
    algorithm: str
    num_samples: int
    num_warmup: int
    save_warmup: bool = False
    thin: int = 1
    # NOTE(review): optional — presumably only emitted for tree-depth-based
    # samplers (NUTS); confirm against CmdStan output.
    max_depth: int | None = None


class OptimizeConfig(BaseModel):
    """Config recorded for a CmdStan ``optimize`` run.

    Unrecognized keys from the config JSON are retained via
    ``extra="allow"``.
    """

    model_config = ConfigDict(extra="allow")

    # Discriminator tag used by StanConfig.method_config.
    method: Literal["optimize"] = "optimize"
    algorithm: str
    save_iterations: bool = False
    jacobian: bool = False


class PathfinderConfig(BaseModel):
    """Config recorded for a CmdStan ``pathfinder`` run.

    Unrecognized keys from the config JSON are retained via
    ``extra="allow"``.
    """

    model_config = ConfigDict(extra="allow")

    # Discriminator tag used by StanConfig.method_config.
    method: Literal["pathfinder"] = "pathfinder"
    num_draws: int = 1000
    num_paths: int = 4
    psis_resample: bool = True
    calculate_lp: bool = True


class LaplaceConfig(BaseModel):
    """Config recorded for a CmdStan ``laplace`` run.

    Unrecognized keys from the config JSON are retained via
    ``extra="allow"``.
    """

    model_config = ConfigDict(extra="allow")

    # Discriminator tag used by StanConfig.method_config.
    method: Literal["laplace"] = "laplace"
    # NOTE(review): presumably the path to the mode (optimization) output
    # file passed to CmdStan — confirm.
    mode: str
    draws: int = 1000
    jacobian: bool = True


class VariationalConfig(BaseModel):
    """Config recorded for a CmdStan ``variational`` (ADVI) run.

    Unrecognized keys from the config JSON are retained via
    ``extra="allow"``.
    """

    model_config = ConfigDict(extra="allow")

    # Discriminator tag used by StanConfig.method_config.
    method: Literal["variational"] = "variational"
    algorithm: str
    iter: int = 10000
    grad_samples: int = 1
    elbo_samples: int = 100
    eta: float = 1.0
    tol_rel_obj: float = 0.01
    eval_elbo: int = 100
    output_samples: int = 1000


class GeneratedQuantitiesConfig(BaseModel):
    """Config recorded for a CmdStan ``generate_quantities`` run.

    Unrecognized keys from the config JSON are retained via
    ``extra="allow"``.
    """

    model_config = ConfigDict(extra="allow")

    # Discriminator tag used by StanConfig.method_config.
    method: Literal["generate_quantities"] = "generate_quantities"
    # NOTE(review): presumably the CSV path(s) of the previous fit supplied
    # to CmdStan as fitted_params — confirm.
    fitted_params: str
    num_chains: int = 1


class StanConfig(BaseModel):
    """Common representation of a config JSON file output as part of a
    Stan inference run. Separate method-specific config classes handle
    the variation of output between methods."""

    model_config = ConfigDict(extra="allow")

    model_name: str
    # Stan version components, kept as strings exactly as they appear in
    # the config JSON.
    stan_major_version: str
    stan_minor_version: str
    stan_patch_version: str

    # Discriminated union: pydantic picks the concrete config class by
    # inspecting each member's "method" literal field.
    method_config: Annotated[
        SampleConfig
        | OptimizeConfig
        | PathfinderConfig
        | LaplaceConfig
        | VariationalConfig
        | GeneratedQuantitiesConfig,
        Discriminator("method"),
    ]


def flatten_value_dict(data: dict[str, Any]) -> dict[str, Any]:
    """Recursively flatten CmdStan's nested value/subdict structure.

    CmdStan uses a pattern where a field contains::

        {"value": "val", "val": {"k1": v1, "k2": v2}}

    This flattens it to the parent level as::

        {"field": "val", "k1": v1, "k2": v2}

    The flattening is applied recursively to any nested dicts. Keys that
    were already placed in the result are never overwritten by keys pulled
    up from a nested sub-dict.
    """
    flat: dict[str, Any] = {}

    for name, entry in data.items():
        if not isinstance(entry, dict):
            # Scalar (or non-dict) leaf: copy straight through.
            flat[name] = entry
        elif "value" in entry:
            selected = entry["value"]
            flat[name] = selected

            # Hoist the keys of the sub-dict matching the selected value
            # up to this level, keeping any previously-set keys intact.
            inner = entry.get(selected, {})
            if isinstance(inner, dict):
                for sub_name, sub_val in flatten_value_dict(inner).items():
                    flat.setdefault(sub_name, sub_val)
        else:
            # Plain nested dict without the value pattern: recurse.
            flat[name] = flatten_value_dict(entry)

    return flat


def flatten_config(data: dict[str, Any]) -> dict[str, Any]:
    """Flatten nested CmdStan config JSON structure.

    CmdStan outputs config JSON with deeply nested structure like::

        {"method": {"value": "sample", "sample": {"num_samples": 1000, ...}}}

    This flattens it to::

        {"method_config": {"method": "sample", "num_samples": 1000, ...}, ...}
    """
    method_section = data.get('method')
    if not isinstance(method_section, dict):
        # No method block to flatten; hand the input back untouched.
        return data

    # Flatten the sub-dict named by the selected method, then record the
    # method name itself alongside its options.
    chosen = method_section.get('value')
    method_config = flatten_value_dict(method_section.get(chosen, {}))
    method_config['method'] = chosen

    flattened = {key: val for key, val in data.items() if key != "method"}
    flattened['method_config'] = method_config
    return flattened


def parse_config(json_data: str | bytes) -> StanConfig:
    """Parse a CmdStan config JSON string into a StanConfig."""
    flattened = flatten_config(json.loads(json_data))
    return StanConfig.model_validate(flattened)  # type: ignore
Loading