Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor out shared LSTM/GGNN training loop #126

Draft
wants to merge 1 commit into
base: development
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions programl/task/dataflow/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ py_library(
deps = [
":graph_loader",
"//programl/models:async_batch_builder",
"//programl/models:base_batch_builder",
"//programl/models:model",
"//programl/models/ggnn",
"//programl/proto:checkpoint_py",
"//programl/proto:epoch_py",
Expand Down
71 changes: 70 additions & 1 deletion programl/task/dataflow/dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@
import warnings
from typing import Tuple

from labm8.py import app, pbutil
from labm8.py import app, humanize, pbutil
from sklearn.exceptions import UndefinedMetricWarning

from programl.models.base_batch_builder import BaseBatchBuilder
from programl.models.model import Model
from programl.proto import checkpoint_pb2, epoch_pb2

app.DEFINE_string(
Expand Down Expand Up @@ -208,3 +210,70 @@ def CreateLoggingDirectories(
(log_dir / "checkpoints").mkdir()
(log_dir / "graph_loader").mkdir()
return log_dir


def run_training_loop(
    log_dir: pathlib.Path,
    epochs,
    val_batches: BaseBatchBuilder,
    start_epoch_step: int,
    model: Model,
    val_graph_count: int,
) -> pathlib.Path:
    """Run a train/validate loop over a sequence of training epochs.

    Each epoch trains the model on a stream of batches, runs validation,
    prints the resulting Epoch proto, and writes both an EpochList proto and
    a model checkpoint under log_dir.

    Args:
      log_dir: The logging directory.
      epochs: An epoch batch builder, yielding
        (train_graph_count, train_graph_cumsum, train_batches) tuples.
      val_batches: A batch builder for validation.
      start_epoch_step: The initial step count.
      model: The model to train.
      val_graph_count: The number of validation graphs.

    Returns:
      The log_dir first argument.
    """
    for step, (graph_count, graph_cumsum, batches) in enumerate(
        epochs, start=start_epoch_step
    ):
        epoch_start = time.time()
        graphs_so_far = f"{humanize.Commas(graph_cumsum)} graphs"

        # Train on this epoch's batches, then measure validation performance
        # at the same point in training.
        train_results = model.RunBatches(
            epoch_pb2.TRAIN,
            batches,
            log_prefix=f"Train to {graphs_so_far}",
            total_graph_count=graph_count,
        )
        val_results = model.RunBatches(
            epoch_pb2.VAL,
            val_batches.batches,
            log_prefix=f"Val at {graphs_so_far}",
            total_graph_count=val_graph_count,
        )

        # Write the epoch to file as an epoch list. This may seem redundant
        # since the epoch list contains a single item, but it means that we can
        # easily concatenate a sequence of these epoch protos to produce a
        # valid epoch list using: `cat *.EpochList.pbtxt > epochs.pbtxt`
        epoch = epoch_pb2.EpochList(
            epoch=[
                epoch_pb2.Epoch(
                    walltime_seconds=time.time() - epoch_start,
                    epoch_num=step,
                    train_results=train_results,
                    val_results=val_results,
                )
            ]
        )
        print(epoch, end="")

        epoch_path = log_dir / "epochs" / f"{step:03d}.EpochList.pbtxt"
        pbutil.ToFile(epoch, epoch_path)
        app.Log(1, "Wrote %s", epoch_path)

        checkpoint_path = log_dir / "checkpoints" / f"{step:03d}.Checkpoint.pb"
        pbutil.ToFile(model.SaveCheckpoint(), checkpoint_path)

    return log_dir
47 changes: 3 additions & 44 deletions programl/task/dataflow/ggnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,50 +173,9 @@ def TrainDataflowGGNN(
)
)

for (
epoch_step,
(train_graph_count, train_graph_cumsum, train_batches),
) in enumerate(epochs, start=start_epoch_step):
start_time = time.time()
hr_graph_cumsum = f"{humanize.Commas(train_graph_cumsum)} graphs"

train_results = model.RunBatches(
epoch_pb2.TRAIN,
train_batches,
log_prefix=f"Train to {hr_graph_cumsum}",
total_graph_count=train_graph_count,
)
val_results = model.RunBatches(
epoch_pb2.VAL,
val_batches.batches,
log_prefix=f"Val at {hr_graph_cumsum}",
total_graph_count=val_graph_count,
)

# Write the epoch to file as an epoch list. This may seem redundant since
# epoch list contains a single item, but it means that we can easily
# concatenate a sequence of these epoch protos to produce a valid epoch
# list using: `cat *.EpochList.pbtxt > epochs.pbtxt`
epoch = epoch_pb2.EpochList(
epoch=[
epoch_pb2.Epoch(
walltime_seconds=time.time() - start_time,
epoch_num=epoch_step,
train_results=train_results,
val_results=val_results,
)
]
)
print(epoch, end="")

epoch_path = log_dir / "epochs" / f"{epoch_step:03d}.EpochList.pbtxt"
pbutil.ToFile(epoch, epoch_path)
app.Log(1, "Wrote %s", epoch_path)

checkpoint_path = log_dir / "checkpoints" / f"{epoch_step:03d}.Checkpoint.pb"
pbutil.ToFile(model.SaveCheckpoint(), checkpoint_path)

return log_dir
return dataflow.run_training_loop(
log_dir, epochs, val_batches, start_epoch_step, model, val_graph_count
)


def TestDataflowGGNN(
Expand Down
47 changes: 3 additions & 44 deletions programl/task/dataflow/train_lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,50 +160,9 @@ def TrainDataflowLSTM(
)
)

for (
epoch_step,
(train_graph_count, train_graph_cumsum, train_batches),
) in enumerate(epochs, start=start_epoch_step):
start_time = time.time()
hr_graph_cumsum = f"{humanize.Commas(train_graph_cumsum)} graphs"

train_results = model.RunBatches(
epoch_pb2.TRAIN,
train_batches,
log_prefix=f"Train to {hr_graph_cumsum}",
total_graph_count=train_graph_count,
)
val_results = model.RunBatches(
epoch_pb2.VAL,
val_batches.batches,
log_prefix=f"Val at {hr_graph_cumsum}",
total_graph_count=FLAGS.val_graph_count,
)

# Write the epoch to file as an epoch list. This may seem redundant since
# epoch list contains a single item, but it means that we can easily
# concatenate a sequence of these epoch protos to produce a valid epoch
# list using: `cat *.EpochList.pbtxt > epochs.pbtxt`
epoch = epoch_pb2.EpochList(
epoch=[
epoch_pb2.Epoch(
walltime_seconds=time.time() - start_time,
epoch_num=epoch_step,
train_results=train_results,
val_results=val_results,
)
]
)
print(epoch, end="")

epoch_path = log_dir / "epochs" / f"{epoch_step:03d}.EpochList.pbtxt"
pbutil.ToFile(epoch, epoch_path)
app.Log(1, "Wrote %s", epoch_path)

checkpoint_path = log_dir / "checkpoints" / f"{epoch_step:03d}.Checkpoint.pb"
pbutil.ToFile(model.SaveCheckpoint(), checkpoint_path)

return log_dir
return dataflow.run_training_loop(
log_dir, epochs, val_batches, start_epoch_step, model, FLAGS.val_graph_count
)


def TestDataflowLSTM(
Expand Down