Skip to content

Commit c4e6ef9

Browse files
committed
diviner: first-class support for replicates
Summary: This change adds first-class support for replicates to Diviner. If a study is configured with replicates, each value combination is run that many times; their metrics are averaged. The study's run function is given the replicate number, so that this can be used, e.g., to select outer folds. Trials are constructed from all available replicates, and a trial is not considered complete until all of its replicates are computed. The leaderboard is computed using aggregated trials. Reviewers: ysaito, dfilippova Reviewed By: ysaito Subscribers: ayip, vnicula, onikolic, joshnewman Differential Revision: https://phabricator.grailbio.com/D29971 fbshipit-source-id: 41350eb
1 parent 2bcaf62 commit c4e6ef9

File tree

20 files changed

+863
-248
lines changed

20 files changed

+863
-248
lines changed

cmd/diviner/main.go

Lines changed: 57 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -458,8 +458,8 @@ the given study names.`)
458458
valueKeys := make(map[string]bool)
459459
for i := range studies {
460460
for _, run := range runs[i] {
461-
for name := range run.Values {
462-
valueKeys[name] = true
461+
for key := range run.Values {
462+
valueKeys[key] = true
463463
}
464464
}
465465
}
@@ -508,6 +508,7 @@ var (
508508
objective: {{.Objective}}{{range $_, $value := .Params.Sorted }}
509509
{{$value.Name}}: {{$value.Param}}{{end}}
510510
oracle: {{printf "%T" .Oracle}}
511+
replicates: {{.Replicates}}
511512
`))
512513

513514
runFuncMap = template.FuncMap{
@@ -521,6 +522,7 @@ var (
521522
created: {{.run.Created.Local}}
522523
runtime: {{.run.Runtime}}
523524
restarts: {{.run.Retries}}
525+
replicate: {{.run.Replicate}}
524526
values:{{range $_, $value := .run.Values.Sorted }}
525527
{{$value.Name}}: {{$value.Value}}{{end}}{{if .verbose}}{{range $index, $metrics := .run.Metrics}}
526528
metrics[{{$index}}]:{{range $_, $metric := $metrics.Sorted}}
@@ -658,9 +660,10 @@ column; missing values are denoted by "NA".`)
658660

659661
func run(db diviner.Database, args []string) {
660662
var (
661-
flags = flag.NewFlagSet("run", flag.ExitOnError)
662-
ntrials = flags.Int("trials", 1, "number of trials to run in each round")
663-
nrounds = flags.Int("rounds", 1, "number of rounds to run")
663+
flags = flag.NewFlagSet("run", flag.ExitOnError)
664+
ntrials = flags.Int("trials", 1, "number of trials to run in each round")
665+
nrounds = flags.Int("rounds", 1, "number of rounds to run")
666+
replicate = flags.Int("replicate", 0, "replicate to re-run")
664667
)
665668
flags.Usage = func() {
666669
fmt.Fprintln(os.Stderr, `usage: diviner run [-rounds n] [-trials n] script.dv [studies]
@@ -760,7 +763,7 @@ If no studies are specified, all defined studies are run concurrently.`)
760763
}
761764
err = traverse.Each(len(runs), func(i int) (err error) {
762765
log.Printf("repeating run %s", args[i])
763-
runs[i], err = runner.Run(ctx, runsStudy[i], runs[i].Values)
766+
runs[i], err = runner.Run(ctx, runsStudy[i], runs[i].Values, *replicate)
764767
return
765768
})
766769
if err != nil {
@@ -804,7 +807,9 @@ func showScript(db diviner.Database, args []string) {
804807
Script renders a bash script to standard output containing a function
805808
for each of the study's datasets and the study itself. The study's
806809
parameter values can be specified via flags; unspecified parameter
807-
values are sampled randomly from valid parameter values.`)
810+
values are sampled randomly from valid parameter values. The
811+
study is always invoked with replicate 0.
812+
`)
808813
flags.PrintDefaults()
809814
os.Exit(2)
810815
}
@@ -840,7 +845,7 @@ values are sampled randomly from valid parameter values.`)
840845
}
841846
var (
842847
study = matched[0]
843-
values = make(diviner.Values)
848+
values diviner.Values
844849
rng = rand.New(rand.NewSource(0))
845850
)
846851
flags = flag.NewFlagSet(study.Name, flag.ExitOnError)
@@ -878,7 +883,7 @@ values are sampled randomly from valid parameter values.`)
878883
log.Fatalf("value %s is not valid for parameter %s %s", val, name, p)
879884
}
880885
}
881-
config, err := study.Run(values, *ident)
886+
config, err := study.Run(values, 0, *ident)
882887
if err != nil {
883888
log.Fatal(err)
884889
}
@@ -895,11 +900,9 @@ func leaderboard(db diviner.Database, args []string) {
895900
flags = flag.NewFlagSet("leaderboard", flag.ExitOnError)
896901
objectiveOverride = flags.String("objective", "", "objective to use instead of studies' shared objective")
897902
numEntries = flags.Int("n", 10, "number of top trials to display")
898-
valuesRe = flags.String("values", "^$", "comma-separated list of anchored regular expression matching parameter values to display")
903+
valuesRe = flags.String("values", ".", "comma-separated list of anchored regular expression matching parameter values to display")
899904
metricsRe = flags.String("metrics", "^$", `comma-separated list of anchored regular expression matching additional metrics to display.
900905
Each regex can be prefixed with '+' or '-'. A regex with '+' (or '-'), when combined with -best, will pick the largest (or smallest) metric from each run.`)
901-
best = flags.Bool("best", false,
902-
"If set, pick the best metric from each run. Otherwise, pick the last metric reported.")
903906
)
904907
flags.Usage = func() {
905908
fmt.Fprintln(os.Stderr, `usage: diviner leaderboard [-objective objective] [-n N] [-values values] [-metrics metrics] studies...
@@ -954,16 +957,25 @@ specifying regular expressions for matching them via the flags
954957
} else {
955958
objective = parseObjective(*objectiveOverride)
956959
}
960+
type trial struct {
961+
diviner.Trial
962+
Study string
963+
}
957964
var (
958-
runsMu sync.Mutex
959-
runs []diviner.Run
965+
trialsMu sync.Mutex
966+
trials []trial
960967
)
961968
err := traverser.Each(len(studies), func(i int) error {
962-
r, err := db.ListRuns(ctx, studies[i].Name, diviner.Success, time.Time{})
963-
runsMu.Lock()
964-
runs = append(runs, r...)
965-
runsMu.Unlock()
966-
return err
969+
t, err := diviner.Trials(ctx, db, studies[i])
970+
if err != nil {
971+
return err
972+
}
973+
trialsMu.Lock()
974+
t.Range(func(_ diviner.Value, v interface{}) {
975+
trials = append(trials, trial{v.(diviner.Trial), studies[i].Name})
976+
})
977+
trialsMu.Unlock()
978+
return nil
967979
})
968980
if err != nil {
969981
log.Fatal(err)
@@ -974,27 +986,27 @@ specifying regular expressions for matching them via the flags
974986
metrics = make(map[string]bool)
975987
)
976988

977-
for _, run := range runs {
978-
if _, ok := getMetric(run, objective, *best); !ok {
989+
for _, trial := range trials {
990+
if _, ok := trial.Metrics[objective.Metric]; !ok {
979991
continue
980992
}
981-
runs[n] = run
993+
trials[n] = trial
982994
n++
983-
for _, m := range run.Metrics {
984-
for name := range m {
985-
metrics[name] = true
986-
}
995+
for name := range trial.Metrics {
996+
metrics[name] = true
997+
987998
}
988-
for name := range run.Values {
999+
for name := range trial.Values {
9891000
values[name] = true
9901001
}
9911002
}
992-
if n < len(runs) {
993-
log.Printf("skipping %d runs due to missing metrics", len(runs)-n)
1003+
if n < len(trials) {
1004+
log.Printf("skipping %d trials due to missing metrics", len(trials)-n)
1005+
trials = trials[:n]
9941006
}
995-
sort.SliceStable(runs, func(i, j int) bool {
996-
iv, _ := getMetric(runs[i], objective, *best)
997-
jv, _ := getMetric(runs[j], objective, *best)
1007+
sort.SliceStable(trials, func(i, j int) bool {
1008+
iv, _ := trials[i].Metrics[objective.Metric]
1009+
jv, _ := trials[j].Metrics[objective.Metric]
9981010
switch objective.Direction {
9991011
case diviner.Maximize:
10001012
return jv < iv
@@ -1004,8 +1016,8 @@ specifying regular expressions for matching them via the flags
10041016
panic(objective)
10051017
}
10061018
})
1007-
if *numEntries > 0 && len(runs) > *numEntries {
1008-
runs = runs[:*numEntries]
1019+
if *numEntries > 0 && len(trials) > *numEntries {
1020+
trials = trials[:*numEntries]
10091021
}
10101022
delete(metrics, objective.Metric)
10111023

@@ -1029,13 +1041,12 @@ specifying regular expressions for matching them via the flags
10291041
return metricsOrdered[i0-i].Metric < metricsOrdered[i1-i].Metric
10301042
})
10311043
}
1032-
10331044
var (
10341045
valuesOrdered = matchAndSort(values, *valuesRe)
10351046
tw tabwriter.Writer
10361047
)
10371048
tw.Init(os.Stdout, 4, 4, 1, ' ', 0)
1038-
fmt.Fprintf(&tw, "run\t%s", objective.Metric)
1049+
fmt.Fprintf(&tw, "study\t%s", objective.Metric)
10391050
if len(metricsOrdered) > 0 {
10401051
for _, metric := range metricsOrdered {
10411052
fmt.Fprint(&tw, "\t"+metric.Metric)
@@ -1045,13 +1056,18 @@ specifying regular expressions for matching them via the flags
10451056
fmt.Fprint(&tw, "\t"+strings.Join(valuesOrdered, "\t"))
10461057
}
10471058
fmt.Fprintln(&tw)
1048-
for _, run := range runs {
1049-
v, _ := getMetric(run, objective, *best)
1050-
fmt.Fprintf(&tw, "%s:%d\t%.4g", run.Study, run.Seq, v)
1059+
for _, trial := range trials {
1060+
v := trial.Metrics[objective.Metric]
1061+
sort.Slice(trial.Runs, func(i, j int) bool { return trial.Runs[i].Seq < trial.Runs[j].Seq })
1062+
seqs := make([]string, len(trial.Runs))
1063+
for i := range seqs {
1064+
seqs[i] = fmt.Sprint(trial.Runs[i].Seq)
1065+
}
1066+
fmt.Fprintf(&tw, "%s:%s\t%.4g", trial.Study, strings.Join(seqs, ","), v)
10511067
if len(metricsOrdered) > 0 {
10521068
metrics := make([]string, len(metricsOrdered))
10531069
for i, metric := range metricsOrdered {
1054-
if v, ok := getMetric(run, metric, *best); ok {
1070+
if v, ok := trial.Metrics[metric.Metric]; ok {
10551071
metrics[i] = fmt.Sprintf("%.3g", v)
10561072
} else {
10571073
metrics[i] = "NA"
@@ -1062,7 +1078,7 @@ specifying regular expressions for matching them via the flags
10621078
if len(valuesOrdered) > 0 {
10631079
values := make([]string, len(valuesOrdered))
10641080
for i, name := range valuesOrdered {
1065-
if v, ok := run.Values[name]; ok {
1081+
if v, ok := trial.Values[name]; ok {
10661082
switch v.Kind() {
10671083
default:
10681084
values[i] = fmt.Sprint(v)
@@ -1220,27 +1236,3 @@ func splitName(name string) (study string, seq uint64) {
12201236
panic(parts)
12211237
}
12221238
}
1223-
1224-
func getMetric(run diviner.Run, objective diviner.Objective, best bool) (float64, bool) {
1225-
found := false
1226-
bestValue := 0.0
1227-
for i := len(run.Metrics) - 1; i >= 0; i-- {
1228-
m, ok := run.Metrics[i][objective.Metric]
1229-
if !ok {
1230-
continue
1231-
}
1232-
if !best {
1233-
return m, ok
1234-
}
1235-
switch {
1236-
case !found:
1237-
bestValue = m
1238-
found = true
1239-
case objective.Direction == diviner.Maximize && m > bestValue:
1240-
bestValue = m
1241-
case objective.Direction == diviner.Minimize && m < bestValue:
1242-
bestValue = m
1243-
}
1244-
}
1245-
return bestValue, found
1246-
}

database.go

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ type Run struct {
5656
// a run.
5757
Seq uint64
5858

59+
// Replicate is the replicate of this run.
60+
Replicate int
61+
5962
// State is the current state of the run. See RunState for
6063
// descriptions of these.
6164
State RunState
@@ -90,7 +93,8 @@ type Run struct {
9093
// TODO(marius): allow other metric selection policies
9194
// (e.g., minimize train and test loss difference)
9295
func (r Run) Trial() Trial {
93-
trial := Trial{Values: r.Values, Pending: r.State != Success}
96+
trial := Trial{Values: r.Values, Pending: r.State != Success, Runs: []Run{r}}
97+
trial.Replicates.Set(r.Replicate)
9498
if len(r.Metrics) > 0 {
9599
trial.Metrics = r.Metrics[len(r.Metrics)-1]
96100
}
@@ -148,3 +152,58 @@ type Database interface {
148152
// for the run named by a study and sequence number.
149153
Logger(study string, seq uint64) io.WriteCloser
150154
}
155+
156+
// Trials queries the database db for all runs in the provided study,
157+
// and returns a set of composite trials for each replicate of a
158+
// value set. The returned map maps value sets to these composite
159+
// trials.
160+
//
161+
// Trial metrics are averaged across successful and pending runs;
162+
// flags are set on the returned trials to indicate which replicates
163+
// they comprise and whether any pending results were used.
164+
//
165+
// TODO(marius): this is a reasonable approach for some metrics, but
166+
// not for others. We should provide a way for users to (e.g., as
167+
// part of a study definition) to define their own means of defining
168+
// composite metrics, e.g., by intepreting metrics from each run, or
169+
// their outputs directly (e.g., predictions from an evaluation run).
170+
func Trials(ctx context.Context, db Database, study Study) (*Map, error) {
171+
runs, err := db.ListRuns(ctx, study.Name, Success|Pending, time.Time{})
172+
if err != nil && err != ErrNotExist {
173+
return nil, err
174+
}
175+
replicates := NewMap()
176+
for i := range runs {
177+
var trials []Trial
178+
if v, ok := replicates.Get(runs[i].Values); ok {
179+
trials = v.([]Trial)
180+
}
181+
trials = append(trials, runs[i].Trial())
182+
replicates.Put(runs[i].Values, trials)
183+
}
184+
trials := NewMap()
185+
replicates.Range(func(key Value, v interface{}) {
186+
var (
187+
reps = v.([]Trial)
188+
counts = make(map[string]int)
189+
values = key.(Values)
190+
trial = Trial{Values: values, Metrics: make(Metrics)}
191+
)
192+
for _, rep := range reps {
193+
if trial.Replicates&rep.Replicates != 0 {
194+
// TODO(marius): pick "best" replicate?
195+
continue
196+
}
197+
for name, value := range rep.Metrics {
198+
counts[name]++
199+
n := float64(counts[name])
200+
trial.Metrics[name] = value/n + trial.Metrics[name]*(n-1)/n
201+
}
202+
trial.Pending = trial.Pending || rep.Pending
203+
trial.Replicates |= rep.Replicates
204+
trial.Runs = append(trial.Runs, rep.Runs...)
205+
}
206+
trials.Put(&values, trial)
207+
})
208+
return trials, nil
209+
}

0 commit comments

Comments
 (0)