Example for the radon model:

```python
import numpy as np
import pandas as pd
import pymc as pm

# Load the radon dataset
data = pd.read_csv(pm.get_data("radon.csv"))
data["log_radon"] = data["log_radon"].astype(np.float64)
county_idx, counties = pd.factorize(data.county)
coords = {"county": counties, "obs_id": np.arange(len(county_idx))}

# Create a simple hierarchical model for the radon dataset
with pm.Model(coords=coords, check_bounds=False) as model:
    intercept = pm.Normal("intercept", sigma=10)

    # County effects
    # TODO should be a CenteredNormal
    raw = pm.Normal("county_raw", dims="county")
    sd = pm.HalfNormal("county_sd")
    county_effect = pm.Deterministic("county_effect", raw * sd, dims="county")

    # Global floor effect
    floor_effect = pm.Normal("floor_effect", sigma=2)

    # County:floor interaction
    # Should also be a CenteredNormal
    raw = pm.Normal("county_floor_raw", dims="county")
    sd = pm.HalfNormal("county_floor_sd")
    county_floor_effect = pm.Deterministic(
        "county_floor_effect", raw * sd, dims="county"
    )

    mu = (
        intercept
        + county_effect[county_idx]
        + floor_effect * data.floor.values
        + county_floor_effect[county_idx] * data.floor.values
    )
    sigma = pm.HalfNormal("sigma", sigma=1.5)
    pm.Normal(
        "log_radon", mu=mu, sigma=sigma, observed=data.log_radon.values, dims="obs_id"
    )

# Compile the joint log-density and its gradient, then pull the graph out of
# the underlying Aesara function.
f = model.logp_dlogp_function()
fgraph = f._aesara_function.vm.fgraph
toposort = fgraph.toposort()
# `fgraph.orderings()` only returns the *extra* constraints contributed by graph
# features (e.g. the destroy handler), so also add each node's data
# dependencies, i.e. the owners of its inputs.
extra_orderings = fgraph.orderings()
dependencies = {
    node: {inp.owner for inp in node.inputs if inp.owner is not None}
    | set(extra_orderings.get(node, ()))
    for node in toposort
}

# Emulate `n_workers` workers. evaluations[k][i] is the node handed out at
# step k (None = idle); row 0 is a dummy so that every worker starts by
# reporting that nothing has finished yet.
n_workers = 8
evaluations = [[None] * n_workers]
try:
    # `scheduler` is the priority-toposort generator from the original post
    # (not included here).
    sched = scheduler(toposort, dependencies)
    while True:
        step_i = []
        pad = 0
        for i in range(n_workers):
            # Report the node finished in the previous step, get the next one.
            task = sched.send(evaluations[-1][i])
            if task is not None:
                step_i.append(task)
            else:
                pad += 1
        # Idle workers ask again: nodes unblocked during this step may be ready now.
        for _ in range(pad):
            step_i.append(sched.send(None))
        evaluations.append(step_i)
except StopIteration:
    # The scheduler generator is exhausted.
    pass
```
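The snippet above relies on a `scheduler` generator that isn't shown. As a hedged illustration only (not the implementation from the post), here is a minimal sketch that speaks the same `send()` protocol, assuming a priority queue over the toposort: `send(done)` marks `done` as finished (`None` means nothing finished) and returns the next ready node or `None`, and the generator returns once every node has finished. The `priority` callback and the `emulate` helper (the driver loop above wrapped in a function so it runs on a toy graph without PyMC) are hypothetical.

```python
import heapq
from typing import Callable, Dict, Generator, Hashable, List, Optional, Sequence, Set

Node = Hashable


def scheduler(
    toposort: Sequence[Node],
    dependencies: Dict[Node, Set[Node]],
    priority: Callable[[Node], float] = lambda node: 0.0,
) -> Generator[Optional[Node], Optional[Node], None]:
    """Hypothetical priority scheduler driven through generator.send().

    send(done) marks `done` as finished (None = nothing finished) and returns
    the next ready node, or None if no node is ready right now. The generator
    returns (raising StopIteration in the caller) once all nodes have finished.
    """
    position = {node: i for i, node in enumerate(toposort)}
    nodes = set(position)
    # Ignore dependencies that are not nodes of this graph.
    remaining = {n: set(dependencies.get(n, ())) & nodes for n in toposort}
    dependents: Dict[Node, Set[Node]] = {n: set() for n in toposort}
    for node, deps in remaining.items():
        for dep in deps:
            dependents[dep].add(node)

    # Max-priority heap; the toposort position breaks ties deterministically.
    ready: List[tuple] = []

    def push(node: Node) -> None:
        heapq.heappush(ready, (-priority(node), position[node], node))

    for node in toposort:
        if not remaining[node]:
            push(node)

    finished: Set[Node] = set()
    done: Optional[Node] = None
    while True:
        if done is not None:
            finished.add(done)
            # Release dependents whose last unfinished dependency was `done`.
            for child in dependents[done]:
                remaining[child].discard(done)
                if not remaining[child]:
                    push(child)
        if len(finished) == len(position):
            return
        task = heapq.heappop(ready)[2] if ready else None
        done = yield task


def emulate(
    toposort: Sequence[Node],
    dependencies: Dict[Node, Set[Node]],
    n_workers: int,
    priority: Callable[[Node], float] = lambda node: 0.0,
) -> List[List[Optional[Node]]]:
    """The driver loop from the comment above, wrapped in a hypothetical helper."""
    evaluations: List[List[Optional[Node]]] = [[None] * n_workers]
    sched = scheduler(toposort, dependencies, priority)
    try:
        while True:
            step: List[Optional[Node]] = []
            pad = 0
            for i in range(n_workers):
                task = sched.send(evaluations[-1][i])
                if task is not None:
                    step.append(task)
                else:
                    pad += 1
            for _ in range(pad):
                step.append(sched.send(None))
            evaluations.append(step)
    except StopIteration:
        pass
    # Drop the dummy first row; each remaining row is one emulated step.
    return evaluations[1:]


# Toy diamond graph: a -> (b, c) -> d, with two emulated workers.
print(emulate(["a", "b", "c", "d"], {"b": {"a"}, "c": {"a"}, "d": {"b", "c"}}, n_workers=2))
# [['a', None], ['b', 'c'], ['d', None]]
```

A node is only handed out after all of its dependencies were reported finished in an earlier step, so the non-`None` entries of each row returned by `emulate` are mutually independent, which is what the batching/fusion use case needs.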
@ferrine IIUC this is only relevant for the C backend. The JAX and Numba backends compile a single thunk/JIT graph that is evaluated as a monolith, and JAX is already multi-threaded internally. Since the goal is to deprecate the C backend, I am not sure this is useful work. Did I misinterpret your idea?
I was thinking quite a lot about execution schedulers today. The implementation itself is only a few lines of code and relies on a toposort with priorities (e.g. compute intensity).
Given that we can get the thunks and their dependency graph, we can apply this kind of scheduling with multiple workers running in different threads. The workers contribute their results and the scheduler decides on the next task.
The awaiting list lets some workers keep long-running jobs while others eagerly fetch new tasks from the compute graph. Another use case for this scheduler is to emulate workers and collect, in batches, nodes that can be computed independently, to see whether we can fuse them.
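For the threaded variant, a minimal sketch could use the standard library's `concurrent.futures` thread pool in place of hand-rolled worker threads, assuming each task is a zero-argument callable (as thunks are) and that the graph is acyclic. `run_threaded`, its parameters and the `awaiting` dict are hypothetical names; `awaiting` plays the role of the awaiting list: long-running jobs stay in flight while `wait(..., return_when=FIRST_COMPLETED)` returns as soon as any one of them finishes, so freed workers immediately receive the next highest-priority ready task.

```python
import heapq
from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait
from typing import Callable, Dict, Hashable, List, Set


def run_threaded(
    tasks: Dict[Hashable, Callable[[], object]],
    dependencies: Dict[Hashable, Set[Hashable]],
    n_workers: int = 4,
    priority: Callable[[Hashable], float] = lambda node: 0.0,
) -> Dict[Hashable, object]:
    """Run a DAG of zero-argument callables on a thread pool (hypothetical sketch).

    Returns a {node: result} dict. Assumes the dependency graph is acyclic.
    """
    position = {node: i for i, node in enumerate(tasks)}
    nodes = set(tasks)
    remaining = {n: set(dependencies.get(n, ())) & nodes for n in tasks}
    dependents: Dict[Hashable, Set[Hashable]] = {n: set() for n in tasks}
    for node, deps in remaining.items():
        for dep in deps:
            dependents[dep].add(node)

    # Max-priority heap of ready nodes; insertion position breaks ties.
    ready: List[tuple] = [(-priority(n), position[n], n) for n in tasks if not remaining[n]]
    heapq.heapify(ready)
    awaiting: Dict[Future, Hashable] = {}  # in-flight (possibly long-running) jobs
    results: Dict[Hashable, object] = {}

    with ThreadPoolExecutor(max_workers=n_workers) as pool:
        while ready or awaiting:
            # Hand the highest-priority ready nodes to every idle worker.
            while ready and len(awaiting) < n_workers:
                node = heapq.heappop(ready)[2]
                awaiting[pool.submit(tasks[node])] = node
            # Block until at least one job finishes; the others keep running.
            done, _ = wait(awaiting, return_when=FIRST_COMPLETED)
            for fut in done:
                node = awaiting.pop(fut)
                results[node] = fut.result()  # re-raises if the task failed
                for child in dependents[node]:
                    remaining[child].discard(node)
                    if not remaining[child]:
                        heapq.heappush(ready, (-priority(child), position[child], child))
    return results


if __name__ == "__main__":
    demo = {name: (lambda name=name: name.upper()) for name in "abcd"}
    print(run_threaded(demo, {"b": {"a"}, "c": {"a"}, "d": {"b", "c"}}, n_workers=2))
```

Python threads only execute thunks in parallel when those thunks release the GIL (as compiled C code can), so this sketch illustrates the scheduling logic rather than promising a speedup.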