Skip to content

Commit

Permalink
2d example
Browse files Browse the repository at this point in the history
  • Loading branch information
praksharma committed Mar 11, 2024
1 parent 15853e1 commit b6b9fc3
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 58 deletions.
5 changes: 2 additions & 3 deletions DeepINN/domain/standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,9 @@ def __init__(self, pde_equation, collocation_object, bc_object) -> None:
super().__init__()

self.pde = pde_equation
self.pde_sampler = collocation_object
self.pde_sampler = collocation_object # class of the PDE
# if the bcs is not list, then make it a list
self.bc_sampler = bc_object if isinstance(bc_object, (list)) else [bc_object]

self.bc_sampler = bc_object if isinstance(bc_object, (list)) else [bc_object] # class of the boundary condition
self.bc_list_len = len(self.bc_sampler)

def sample_collocation_points(self):
Expand Down
7 changes: 5 additions & 2 deletions DeepINN/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ def compile(self, optimiser_string : str, lr : float, metrics_string : str, devi

def compile_domain(self):
# sample collocation points
self.collocation_point_sample, self.collocation_point_labels = self.domain.sample_collocation_labels()
self.collocation_point_sample, self.collocation_point_labels = self.domain.sample_collocation_labels() # list of collocation points and their labels both as tensors

# sample boundary points
self.boundary_point_sample, self.boundary_point_labels = self.domain.sample_boundary_labels()
self.boundary_point_sample, self.boundary_point_labels = self.domain.sample_boundary_labels() # list of boundary points and their labels both as tensors
print("Domain compiled", file=sys.stderr, flush=True)

def compile_network(self):
Expand Down Expand Up @@ -61,6 +61,9 @@ def initialise_training(self, iterations : int = None):
if self.boundary_point_sample[0].size()[0] == 1: # if row is 1 in the particular boundary tensor
self.boundary_point_sample = torch.cat(self.boundary_point_sample, dim=0)
self.boundary_point_labels = torch.cat(self.boundary_point_labels, dim=0)
else:
self.boundary_point_sample = torch.cat(self.boundary_point_sample, dim=0)
self.boundary_point_labels = torch.cat(self.boundary_point_labels, dim=0)

# Set requires_grad=True for self.collocation_point_sample
self.collocation_point_sample.requires_grad = True
Expand Down
Loading

0 comments on commit b6b9fc3

Please sign in to comment.