1
+ import idna
1
2
import logging
2
3
import os
3
4
6
7
from pyspark .sql .types import BooleanType , LongType , StringType , StructField , StructType
7
8
8
9
from iana_tld import iana_tld_list
10
+ from wat_extract_links import ExtractHostLinksJob
9
11
10
12
11
13
class HostLinksToGraph (CCSparkJob ):
@@ -17,6 +19,9 @@ class HostLinksToGraph(CCSparkJob):
17
19
def add_arguments (self , parser ):
18
20
parser .add_argument ("--save_as_text" , type = str , default = None ,
19
21
help = "Save webgraph also as text on path" )
22
+ parser .add_argument ("--normalize_host_names" , action = 'store_true' ,
23
+ help = "Normalize host names: replace Unicode IDNs"
24
+ " by their ASCII equivalents" )
20
25
parser .add_argument ("--validate_host_names" , action = 'store_true' ,
21
26
help = "Validate host names and skip vertices with"
22
27
" invalid name during assignment of vertex IDs" )
@@ -42,6 +47,8 @@ def reverse_host(host):
42
47
43
48
@staticmethod
44
49
def reverse_host_is_valid (rev_host ):
50
+ if rev_host is None :
51
+ return False
45
52
if '.' not in rev_host :
46
53
return False
47
54
# fast check for valid top-level domain
@@ -52,17 +59,39 @@ def reverse_host_is_valid(rev_host):
52
59
return False
53
60
return True
54
61
62
+ @staticmethod
63
+ def reverse_host_normalize (rev_host ):
64
+ parts = rev_host .split ('.' )
65
+ modified = False
66
+ for (i , part ) in enumerate (parts ):
67
+ if not ExtractHostLinksJob .host_part_pattern .match (part ):
68
+ try :
69
+ idn = idna .encode (part ).decode ('ascii' )
70
+ parts [i ] = idn
71
+ modified = True
72
+ except (idna .IDNAError , idna .core .InvalidCodepoint , UnicodeError , IndexError , Exception ):
73
+ return None
74
+ if modified :
75
+ return '.' .join (parts )
76
+ return rev_host
77
+
55
78
def vertices_assign_ids (self , session , edges ):
56
79
source = edges .select (edges .s .alias ('name' ))
57
80
target = edges .select (edges .t .alias ('name' ))
58
81
59
82
ids = source .union (target ) \
60
83
.distinct ()
61
84
85
+ if self .args .normalize_host_names :
86
+ normalize = sqlf .udf (HostLinksToGraph .reverse_host_normalize ,
87
+ StringType ())
88
+ ids = ids .withColumn ('name' , normalize (ids ['name' ]))
89
+ ids = ids .dropna ().distinct ()
90
+
62
91
if self .args .validate_host_names :
63
92
is_valid = sqlf .udf (HostLinksToGraph .reverse_host_is_valid ,
64
93
BooleanType ())
65
- ids = ids .filter (is_valid (ids . name ))
94
+ ids = ids .filter (is_valid (ids [ ' name' ] ))
66
95
67
96
if self .args .vertex_partitions == 1 :
68
97
ids = ids \
@@ -104,6 +133,7 @@ def run_job(self, session):
104
133
for add_input in self .args .add_input :
105
134
add_edges = session .read .load (add_input )
106
135
edges = edges .union (add_edges )
136
+
107
137
# remove duplicates and sort
108
138
edges = edges \
109
139
.dropDuplicates () \
0 commit comments