99import time
1010import os
1111import re
12+ from urllib .parse import urlparse
13+ from urllib .request import urlretrieve
1214
1315import networkx as nx
1416import numpy as np
2123
2224from typing import Union
2325
24- from .utils import *
26+ from . import utils
2527
2628
2729class EpiModel (object ):
@@ -325,7 +327,7 @@ def plot(
325327 ax = plt .gca ()
326328
327329 for comp in self .values_ .columns :
328- (self .values_ [comp ] / N ).plot (c = epi_colors [comp [0 ]], ** kwargs )
330+ (self .values_ [comp ] / N ).plot (c = utils . EPI_COLORS [comp [0 ]], ** kwargs )
329331
330332 ax .legend (self .values_ .columns )
331333 ax .set_xlabel ("Time" )
@@ -340,7 +342,7 @@ def plot(
340342 return ax
341343 except Exception as e :
342344 print (e )
343- raise NotInitialized ("You must call integrate() or simulate() first" )
345+ raise utils . NotInitialized ("You must call integrate() or simulate() first" )
344346
345347 def __getattr__ (self , name : str ) -> pd .Series :
346348 """
@@ -670,7 +672,7 @@ def draw_model(self, ax: Union[plt.Axes, None] = None, show: bool = True) -> Non
670672 orig_pos = pos [node [3 :]]
671673 pos [node ] = [orig_pos [0 ], orig_pos [1 ] - 1 ]
672674 else :
673- node_colors .append (epi_colors [node [0 ]])
675+ node_colors .append (utils . EPI_COLORS [node [0 ]])
674676
675677 edge_labels = {}
676678
@@ -821,3 +823,27 @@ def load_model(filename: str) -> None:
821823 model .name = data [key ]
822824
823825 return model
826+
827+
828+ def download_model (filename : str , repo : Union [str , None ] = None ) -> None :
829+ """
830+ Download model from offical repository
831+ """
832+ if repo is None :
833+ repo = utils .OFFICIAL_REPO
834+
835+ parsed_repo = urlparse (repo )
836+
837+ if parsed_repo .netloc == 'github.com' :
838+ repo = repo .replace ('github.com' , 'raw.githubusercontent.com' )
839+ remote_path = repo + os .path .join ('refs/heads/models/models/' , filename )
840+
841+ if not os .path .exists (utils .LOCAL_DIRECTORY ):
842+ os .makedirs (utils .LOCAL_DIRECTORY )
843+
844+ local_path = os .path .join (utils .LOCAL_DIRECTORY , filename )
845+
846+ if not os .path .exists (local_path ):
847+ urlretrieve (remote_path , local_path )
848+
849+ return EpiModel .load_model (local_path )
0 commit comments