training timer implemented
praksharma committed Mar 11, 2024
1 parent be0ec3c commit a829089
Showing 5 changed files with 53 additions and 37 deletions.
22 changes: 16 additions & 6 deletions DeepINN/model.py
@@ -1,7 +1,8 @@
import torch
import sys
from .backend import loss_metric, choose_optimiser
from .config import Config
from .utils import timer

class Model():
"""
@@ -46,8 +47,7 @@ def compile_network(self):
)
print("Network compiled", file=sys.stderr, flush=True)

def train(self, iterations : int = None, display_every : int = None):

def initialise_training(self, iterations : int = None):
if self.iter == 0: # We are running a fresh training
self.training_history = [] # Initialize an empty list for storing loss values
self.iterations = iterations
@@ -65,6 +65,18 @@ def train(self, iterations : int = None, display_every : int = None):
# Set requires_grad=True for self.collocation_point_sample
self.collocation_point_sample.requires_grad = True

def train(self, iterations : int = None, display_every : int = 1):
"""_summary_
Args:
iterations (int): _description_. Number of iterations.
display_every (int, optional): _description_. Display the loss every display_every iterations. Defaults to 1.
"""
self.initialise_training(iterations)
self.trainer()

@timer
def trainer(self):
# implement training loop
while self.iter <= self.iterations:

@@ -93,6 +105,4 @@ def train(self, iterations : int = None, display_every : int = None):
self.iter = self.iter + 1
else:
print('Training finished')
#elapsed = time.time() - start_time
#print('Training time: %.2f' % (elapsed))
#print(f"Final loss: {total_loss}")
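With this commit, the call flow in model.py becomes train() -> initialise_training() -> trainer(), and trainer() is wrapped by the timer decorator (added in DeepINN/utils/user_fun.py below) so the elapsed time is printed once the loop exits. A minimal standalone sketch of that pattern, assuming a package built from this commit so that timer is importable from DeepINN.utils; the ToyModel class and its loop are illustrative stand-ins, not the real Model:

import time
from DeepINN.utils import timer  # re-exported by the __init__.py change below

class ToyModel:
    def __init__(self, iterations):
        self.iterations = iterations

    def train(self):
        # mirrors Model.train: prepare state, then hand off to the timed loop
        self.trainer()

    @timer
    def trainer(self):
        for _ in range(self.iterations):
            time.sleep(0.001)  # stand-in for one optimisation step
        print('Training finished')

ToyModel(100).train()
# prints 'Training finished', then something like:
# Time taken: 'trainer' in 0.1xxx secs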

11 changes: 1 addition & 10 deletions DeepINN/utils/__init__.py
@@ -1,12 +1,3 @@
"""Useful helper methods for the definition and evaluation of a problem.
For the creation of conditions, some differential operators are implemented under
torchphysics.utils.differentialoperators.
For the evaluation of the trained model, some plot and animation functionalities are provided.
They can give you a rough overview of the determined solution. These lay under
torchphysics.utils.plotting
"""
from .differentialoperators import (laplacian,
grad,
div,
@@ -20,7 +11,7 @@

from .data import PointsDataset, PointsDataLoader, DeepONetDataLoader

from .user_fun import UserFunction, tensor2numpy
from .user_fun import UserFunction, tensor2numpy, timer
from .plotting import plot, animate, scatter
from .evaluation import compute_min_and_max

13 changes: 13 additions & 0 deletions DeepINN/utils/user_fun.py
@@ -5,6 +5,8 @@
import inspect
import copy
import torch
import functools
import time

from ..geometry.spaces.points import Points

@@ -317,3 +319,14 @@ def tensor2numpy(tensor_list):
Converts a list of torch.tensors to numpy arrays.
"""
return [tensor.detach().cpu().numpy() for tensor in tensor_list]

def timer(func):
"""Print the runtime of the decorated function"""
@functools.wraps(func)
def wrapper_timer(*args, **kwargs):
start_time = time.perf_counter()
result = func(*args, **kwargs) # execute the decorated function
end_time = time.perf_counter()
run_time = end_time - start_time
print(f"Time taken: {func.__name__!r} in {run_time:.4f} secs")
return result # pass the wrapped function's return value through
return wrapper_timer
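Since the DeepINN/utils/__init__.py change above re-exports timer, the decorator can also be applied to any plain function. A short hedged check, again assuming a package built from this commit; slow_step is a throwaway name used only to show that functools.wraps keeps the original function name in the printed message:

import time
from DeepINN.utils import timer

@timer
def slow_step():
    time.sleep(0.2)  # pretend to do some work

slow_step()                  # prints something like: Time taken: 'slow_step' in 0.2002 secs
print(slow_step.__name__)    # 'slow_step' -- functools.wraps preserves the wrapped name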