-
Notifications
You must be signed in to change notification settings - Fork 2.2k
/
Copy pathutilities.py
467 lines (378 loc) · 14.2 KB
/
utilities.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
import base64
import collections
import copy
import json
import math
import os
import re
import tempfile
import uuid
from contextlib import contextmanager
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
Optional,
Protocol,
Sequence,
Tuple,
Type,
Union,
runtime_checkable,
)
from urllib.parse import urlparse, uses_netloc, uses_params, uses_relative
import numpy as np
from branca.element import Div, Element, Figure
# import here for backwards compatibility
from branca.utilities import ( # noqa F401
_locations_mirror,
_parse_size,
none_max,
none_min,
write_png,
)
try:
import pandas as pd
except ImportError:
pd = None
if TYPE_CHECKING:
from .features import Popup
# A single line: an iterable of coordinate pairs.
TypeLine = Iterable[Sequence[float]]
# Either one line or an iterable of lines.
TypeMultiLine = Union[TypeLine, Iterable[TypeLine]]

# Keyword values accepted by `parse_options` / `remove_empty`; the "NoNone"
# variant is what survives after None values are filtered out.
TypeJsonValueNoNone = Union[str, float, bool, Sequence, dict]
TypeJsonValue = Optional[TypeJsonValueNoNone]

# A bool, str, float or None option value.
TypePathOptions = Optional[Union[bool, str, float]]

# Two coordinate pairs as given by callers, and the normalized form produced
# by `normalize_bounds_type`, where individual entries may be None.
TypeBounds = Sequence[Sequence[float]]
TypeBoundsReturn = List[List[Optional[float]]]

# Figure, Div or Popup. "Popup" is a forward reference because it is only
# imported under TYPE_CHECKING.
TypeContainer = Union[Figure, Div, "Popup"]
# URL schemes that `_is_url` treats as "already a URL": every non-empty
# scheme urllib knows about, plus "data" for inline data URLs.
_VALID_URLS = {
    scheme
    for scheme in (*uses_relative, *uses_netloc, *uses_params, "data")
    if scheme
}
@runtime_checkable
class TypeJsCode(Protocol):
    """Structural type for anything that carries raw JavaScript source.

    Because the protocol is ``runtime_checkable``, ``isinstance(obj,
    TypeJsCode)`` is True for any object exposing a ``js_code`` attribute.
    """

    # we only care about this attribute.
    js_code: str


class JsCode(TypeJsCode):
    """Wrapper around Javascript code.

    Parameters
    ----------
    js_code: str or object with a ``js_code`` attribute
        Raw JavaScript source. If an object that already wraps code is given
        (a ``JsCode`` or anything else satisfying ``TypeJsCode``), its
        ``js_code`` string is copied, so wrappers are never nested.
    """

    def __init__(self, js_code: Union[str, TypeJsCode]):
        # Unwrap any code-carrying object, not only JsCode instances;
        # otherwise such an object would be stored as the "code" itself.
        if isinstance(js_code, TypeJsCode):
            self.js_code: str = js_code.js_code
        else:
            self.js_code = js_code

    def __str__(self):
        return self.js_code
def validate_location(location: Sequence[float]) -> List[float]:
    """Validate a single lat/lon coordinate pair and convert to a list

    Validate that location:
    * is a sized variable
    * with size 2
    * allows indexing (i.e. has an ordering)
    * where both values are floats (or convertible to float)
    * and both values are not NaN

    Numpy arrays and pandas DataFrames are squeezed and converted to plain
    Python lists first, so e.g. ``np.array([[lat, lon]])`` is accepted.

    Returns
    -------
    list of two floats, in the same order as the input.

    Raises
    ------
    TypeError
        If `location` has no length or does not support ``[0]``/``[1]``.
    ValueError
        If `location` does not hold exactly two values, if a value is not
        convertible to float, or if a value is NaN.
    """
    if isinstance(location, np.ndarray) or (
        pd is not None and isinstance(location, pd.DataFrame)
    ):
        location = np.squeeze(location).tolist()
    if not hasattr(location, "__len__"):
        raise TypeError(
            "Location should be a sized variable, "
            "for example a list or a tuple, instead got "
            f"{location!r} of type {type(location)}."
        )
    if len(location) != 2:
        raise ValueError(
            "Expected two (lat, lon) values for location, "
            f"instead got: {location!r}."
        )
    try:
        coords = (location[0], location[1])
    except (TypeError, KeyError) as err:
        # e.g. a two-key dict: sized, but not integer-indexable.
        raise TypeError(
            "Location should support indexing, like a list or "
            f"a tuple does, instead got {location!r} of type {type(location)}."
        ) from err
    result = []
    for coord in coords:
        try:
            value = float(coord)
        except (TypeError, ValueError) as err:
            raise ValueError(
                "Location should consist of two numerical values, "
                f"but {coord!r} of type {type(coord)} is not convertible to float."
            ) from err
        if math.isnan(value):
            raise ValueError("Location values cannot contain NaNs.")
        result.append(value)
    return result
def _validate_locations_basics(locations: TypeMultiLine) -> None:
    """Helper function that does basic validation of line and multi-line types.

    Checks that `locations` is iterable and yields at least one element.

    Note: the emptiness check starts a fresh iteration with
    ``next(iter(locations))``. For a one-shot iterator that consumes its
    first element, so callers should pass a re-iterable container.

    Raises
    ------
    TypeError
        If `locations` is not iterable.
    ValueError
        If `locations` is empty.
    """
    try:
        iter(locations)
    except TypeError as err:
        raise TypeError(
            "Locations should be an iterable with coordinate pairs,"
            f" but instead got {locations!r}."
        ) from err
    try:
        next(iter(locations))
    except StopIteration:
        # The StopIteration itself carries no useful context for the user.
        raise ValueError("Locations is empty.") from None
def validate_locations(locations: TypeLine) -> List[List[float]]:
    """Validate an iterable with lat/lon coordinate pairs.

    Pandas DataFrames are converted to numpy first. One-shot iterators
    (e.g. ``zip(lats, lons)``) are materialized into a list first: the
    emptiness check in `_validate_locations_basics` calls
    ``next(iter(locations))``, which would otherwise swallow the first pair.

    Returns
    -------
    list of ``[float, float]`` coordinate pairs.
    """
    locations = if_pandas_df_convert_to_numpy(locations)
    if isinstance(locations, Iterator):
        locations = list(locations)
    _validate_locations_basics(locations)
    return [validate_location(coord_pair) for coord_pair in locations]
def validate_multi_locations(
    locations: TypeMultiLine,
) -> Union[List[List[float]], List[List[List[float]]]]:
    """Validate an iterable with possibly nested lists of coordinate pairs.

    Returns a list of coordinate pairs when `locations` is a single line,
    or a list of such lists when it is a list of lines.

    The outer level is materialized into a list when it is a one-shot
    iterator (e.g. ``zip(lats, lons)``). The emptiness check and the nesting
    probe below each start a fresh iteration, and would otherwise consume
    leading elements. Inner one-shot iterators are not materialized.
    """
    locations = if_pandas_df_convert_to_numpy(locations)
    if isinstance(locations, Iterator):
        locations = list(locations)
    _validate_locations_basics(locations)
    try:
        # Descend three levels; if that still yields a number, every element
        # of `locations` is itself a list of coordinate pairs.
        float(next(iter(next(iter(next(iter(locations))))))) # type: ignore
    except (TypeError, StopIteration):
        # locations is a list of coordinate pairs
        return [validate_location(coord_pair) for coord_pair in locations] # type: ignore
    else:
        # locations is a list of a list of coordinate pairs, recurse
        return [validate_locations(lst) for lst in locations] # type: ignore
def if_pandas_df_convert_to_numpy(obj: Any) -> Any:
    """Hand back ``obj.values`` for a pandas DataFrame, else `obj` unchanged.

    Iterating a DataFrame directly walks its column labels rather than its
    rows, so DataFrames are swapped for their underlying numpy array. When
    pandas is not installed (``pd`` is None) nothing is converted.
    """
    is_dataframe = pd is not None and isinstance(obj, pd.DataFrame)
    return obj.values if is_dataframe else obj
def image_to_url(
    image: Any,
    colormap: Optional[Callable] = None,
    origin: str = "upper",
) -> str:
    """
    Infers the type of an image argument and transforms it into a URL.

    Parameters
    ----------
    image: string, path-like or array-like object
        * If a string whose scheme is recognised by ``_is_url`` (e.g. an
          http(s) or ``data:`` URL), it is used as the URL directly.
        * If any other string or an ``os.PathLike``, it is treated as the
          path of a file whose content is embedded as a base64 data URL.
          The MIME subtype comes from the file extension (``.svg`` becomes
          ``image/svg+xml``).
        * If array-like (its class name contains ``ndarray``), it is
          converted to a PNG and embedded as a base64 data URL.
    colormap: callable, used only for `mono` image.
        Function of the form [x -> (r,g,b)] or [x -> (r,g,b,a)]
        for transforming a mono image into RGB.
        It must output iterables of length 3 or 4, with values between
        0. and 1. You can use colormaps from `matplotlib.cm`.
    origin: ['upper' | 'lower'], optional, default 'upper'
        Place the [0, 0] index of the array in the upper left or
        lower left corner of the axes.

    Returns
    -------
    str
        The URL, with newline characters replaced by spaces. Any other kind
        of `image` is passed through a JSON round-trip, and the result must
        be a string for that final replacement to work.
    """
    if isinstance(image, os.PathLike):
        # Accept pathlib.Path and similar objects just like string paths.
        image = os.fspath(image)
    if isinstance(image, str) and not _is_url(image):
        fileformat = os.path.splitext(image)[-1][1:]
        if fileformat.lower() == "svg":
            # SVG's registered media type is "image/svg+xml";
            # browsers do not render "image/svg" data URLs.
            fileformat = "svg+xml"
        with open(image, "rb") as f:
            img = f.read()
        b64encoded = base64.b64encode(img).decode("utf-8")
        url = f"data:image/{fileformat};base64,{b64encoded}"
    elif "ndarray" in image.__class__.__name__:
        img = write_png(image, origin=origin, colormap=colormap)
        b64encoded = base64.b64encode(img).decode("utf-8")
        url = f"data:image/png;base64,{b64encoded}"
    else:
        # Round-trip to ensure a nice formatted json.
        url = json.loads(json.dumps(image))
    return url.replace("\n", " ")
def _is_url(url: str) -> bool:
    """Tell whether `url` begins with a scheme listed in `_VALID_URLS`.

    Anything that makes ``urlparse`` raise is reported as "not a URL".
    """
    try:
        scheme = urlparse(url).scheme
    except Exception:
        return False
    return scheme in _VALID_URLS
def mercator_transform(
    data: Any,
    lat_bounds: Tuple[float, float],
    origin: str = "upper",
    height_out: Optional[int] = None,
) -> np.ndarray:
    """
    Transforms an image computed in (longitude,latitude) coordinates into
    a Mercator projection image.

    Parameters
    ----------
    data: numpy array or equivalent list-like object.
        Must be NxM (mono), NxMx3 (RGB) or NxMx4 (RGBA)
    lat_bounds : length 2 tuple
        Minimal and maximal value of the latitude of the image.
        Bounds must be between -85.051128779806589 and 85.051128779806589
        otherwise they will be clipped to that values.
    origin : ['upper' | 'lower'], optional, default 'upper'
        Place the [0,0] index of the array in the upper left or lower left
        corner of the axes.
    height_out : int, default None
        The expected height of the output.
        If None, the height of the input is used.

    Returns
    -------
    numpy.ndarray
        float64 array of shape (height_out, M, layers). A 2-D (mono) input
        gets a single trailing layer, because it goes through
        ``np.atleast_3d``.

    See https://en.wikipedia.org/wiki/Web_Mercator for more details.
    """

    def mercator(x):
        # Latitude in degrees -> Mercator y, also expressed in degrees.
        return np.arcsinh(np.tan(x * np.pi / 180.0)) * 180.0 / np.pi

    # The array is only read below, so no defensive copy is needed.
    array = np.atleast_3d(data)
    height, width, nblayers = array.shape

    lat_min = max(lat_bounds[0], -85.051128779806589)
    lat_max = min(lat_bounds[1], 85.051128779806589)
    if height_out is None:
        height_out = height

    # Eventually flip the image so that row 0 corresponds to lat_min.
    if origin == "upper":
        array = array[::-1, :, :]

    # Input pixel centres, evenly spaced in latitude...
    lats = lat_min + np.linspace(0.5 / height, 1.0 - 0.5 / height, height) * (
        lat_max - lat_min
    )
    # ...and output pixel centres, evenly spaced in Mercator y.
    latslats = mercator(lat_min) + np.linspace(
        0.5 / height_out, 1.0 - 0.5 / height_out, height_out
    ) * (mercator(lat_max) - mercator(lat_min))
    # Loop-invariant: projected positions of the input rows.
    mercator_lats = mercator(lats)

    out = np.zeros((height_out, width, nblayers))
    for i in range(width):
        for j in range(nblayers):
            out[:, i, j] = np.interp(latslats, mercator_lats, array[:, i, j])

    # Eventually flip the image back.
    if origin == "upper":
        out = out[::-1, :, :]
    return out
def iter_coords(obj: Any) -> Iterator[Tuple[float, ...]]:
    """
    Returns all the coordinate tuples from a geometry or feature.

    Accepts a GeoJSON-like mapping (FeatureCollection, Feature,
    GeometryCollection or a geometry with "coordinates") or a nested
    list/tuple of numbers. Each innermost run of numbers is yielded as one
    tuple. Features whose geometry is falsy (e.g. None) and falsy members of
    a GeometryCollection are skipped.
    """
    if isinstance(obj, (tuple, list)):
        coords = obj
    elif "features" in obj:
        for feature in obj["features"]:
            if feature["geometry"]:
                # Recurse on the geometry itself so GeometryCollections work.
                yield from iter_coords(feature["geometry"])
        return
    elif "geometry" in obj:
        if obj["geometry"]:
            yield from iter_coords(obj["geometry"])
        return
    elif "geometries" in obj:
        # GeometryCollection: every member geometry contributes, not only
        # the first one; an empty collection yields nothing.
        for geometry in obj["geometries"]:
            if geometry:
                yield from iter_coords(geometry)
        return
    else:
        coords = obj.get("coordinates", obj)
    for coord in coords:
        # numpy scalars such as int64/float32 are not Python int/float
        # subclasses, so they are listed explicitly.
        if isinstance(coord, (float, int, np.floating, np.integer)):
            yield tuple(coords)
            break
        else:
            yield from iter_coords(coord)
def get_bounds(
    locations: Any,
    lonlat: bool = False,
) -> List[List[Optional[float]]]:
    """
    Computes the bounds of the object in the form
    [[lat_min, lon_min], [lat_max, lon_max]]

    Every coordinate tuple produced by `iter_coords` updates running
    minima/maxima of its first two values (None-aware via `none_min` and
    `none_max`). Entries stay None when no coordinates are found. When
    `lonlat` is True, the result is passed through `_locations_mirror`
    (presumably to swap (lon, lat) input order back to (lat, lon) — see
    branca).
    """
    low_first = low_second = high_first = high_second = None
    for point in iter_coords(locations):
        low_first = none_min(low_first, point[0])
        low_second = none_min(low_second, point[1])
        high_first = none_max(high_first, point[0])
        high_second = none_max(high_second, point[1])
    bounds: List[List[Optional[float]]] = [
        [low_first, low_second],
        [high_first, high_second],
    ]
    if lonlat:
        bounds = _locations_mirror(bounds)
    return bounds
def normalize_bounds_type(bounds: TypeBounds) -> TypeBoundsReturn:
    """Return `bounds` as a list of lists of floats, keeping None entries as None."""
    normalized: TypeBoundsReturn = []
    for corner in bounds:
        normalized.append([None if value is None else float(value) for value in corner])
    return normalized
def camelize(key: str) -> str:
    """Convert a python_style_variable_name to lowerCamelCase.

    The first underscore-separated word is kept verbatim; every following
    word goes through ``str.capitalize``.

    Examples
    --------
    >>> camelize("variable_name")
    'variableName'
    >>> camelize("variableName")
    'variableName'
    """
    first, *rest = key.split("_")
    return first + "".join(word.capitalize() for word in rest)
def compare_rendered(obj1: str, obj2: str) -> bool:
    """Tell whether two rendered folium outputs match once normalized.

    Both inputs go through `normalize` first, so blank lines, leading and
    trailing whitespace on each line, and ", " vs "," are ignored.
    """
    left, right = normalize(obj1), normalize(obj2)
    return left == right


def normalize(rendered: str) -> str:
    """Return the input string without non-functional spaces or newlines.

    Each line is stripped, empty lines are dropped, the remaining lines are
    concatenated, and every ", " is collapsed to ",".
    """
    stripped = (line.strip() for line in rendered.splitlines())
    joined = "".join(filter(None, stripped))
    return joined.replace(", ", ",")
@contextmanager
def temp_html_filepath(data: Union[str, bytes]) -> Iterator[str]:
    """Yields the path of a temporary HTML file containing data.

    `data` may be text (written UTF-8 encoded) or bytes (written as-is).
    The file is removed when the context exits, including on error.
    """
    payload = data.encode("utf8") if isinstance(data, str) else data
    fid, filepath = tempfile.mkstemp(suffix=".html", prefix="folium_")
    try:
        # fdopen's buffered write keeps writing until all bytes are out
        # (os.write may write fewer), and the `with` closes the descriptor
        # even if writing fails.
        with os.fdopen(fid, "wb") as f:
            f.write(payload)
        yield filepath
    finally:
        if os.path.isfile(filepath):
            os.remove(filepath)
def deep_copy(item_original: Element) -> Element:
    """Return a recursive copy of an element tree where each copy has a new ID.

    Each node is duplicated with ``copy.copy`` (so attributes other than
    ``_id``, ``_children`` and ``_parent`` are shared with the original) and
    gets a fresh ``uuid4`` hex id. Children are copied recursively,
    re-parented to the new node, and re-keyed by their new ``get_name()``.
    """
    clone = copy.copy(item_original)
    clone._id = uuid.uuid4().hex
    if not (hasattr(clone, "_children") and len(clone._children) > 0):
        return clone
    fresh_children = collections.OrderedDict()
    for child_original in clone._children.values():
        child = deep_copy(child_original)
        child._parent = clone
        fresh_children[child.get_name()] = child
    clone._children = fresh_children
    return clone
def get_obj_in_upper_tree(element: Element, cls: Type) -> Element:
    """Return the nearest ancestor of `element` that is an instance of `cls`.

    The search starts at ``element._parent`` (the element itself is never
    returned) and follows ``_parent`` links upward.

    Raises
    ------
    ValueError
        If a ``None`` parent is reached before any match.
    """
    node = element._parent
    while node is not None:
        if isinstance(node, cls):
            return node
        node = node._parent
    raise ValueError(f"The top of the tree was reached without finding a {cls}")
def parse_options(**kwargs: TypeJsonValue) -> Dict[str, TypeJsonValueNoNone]:
    """Return the keyword arguments with camelized keys, dropping None values.

    Keys are converted with `camelize`; insertion order is preserved.
    """
    options: Dict[str, TypeJsonValueNoNone] = {}
    for name, val in kwargs.items():
        if val is None:
            continue
        options[camelize(name)] = val
    return options
def remove_empty(**kwargs: TypeJsonValue) -> Dict[str, TypeJsonValueNoNone]:
    """Return the keyword arguments as a dict, leaving out those that are None.

    Other falsy values (0, "", False, empty containers) are kept.
    """
    kept: Dict[str, TypeJsonValueNoNone] = {}
    for name, val in kwargs.items():
        if val is not None:
            kept[name] = val
    return kept
def escape_backticks(text: str) -> str:
    """Escape backticks so text can be used in a JS template literal.

    A backtick already preceded by a backslash is left alone, so escaped
    backticks are not escaped a second time.
    """
    unescaped_backtick = re.compile(r"(?<!\\)`")
    return unescaped_backtick.sub(r"\`", text)
def escape_double_quotes(text: str) -> str:
    """Put a backslash in front of every double quote in `text`.

    Backslashes already present in `text` are not escaped.
    """
    return text.translate({ord('"'): r"\""})


def javascript_identifier_path_to_array_notation(path: str) -> str:
    """Convert a dotted path like obj1.obj2 to bracket notation: ["obj1"]["obj2"].

    Each dot-separated segment becomes a double-quoted property access, with
    its double quotes escaped by `escape_double_quotes`.
    """
    pieces = []
    for segment in path.split("."):
        pieces.append('["' + escape_double_quotes(segment) + '"]')
    return "".join(pieces)
def get_and_assert_figure_root(obj: Element) -> Figure:
    """Return the root element of the tree and assert it's a Figure.

    Raises
    ------
    AssertionError
        If ``obj.get_root()`` is not a ``Figure``. The error is raised
        explicitly rather than with an ``assert`` statement, so the check
        still runs when Python is started with ``-O``.
    """
    figure = obj.get_root()
    if not isinstance(figure, Figure):
        raise AssertionError(
            "You cannot render this Element if it is not in a Figure."
        )
    return figure
def parse_font_size(value: Union[str, int, float]) -> str:
    """Parse a font size value, if number set as px

    Numbers (int or float) get a ``px`` suffix. Strings are returned
    unchanged when they end in ``rem``, ``em`` or ``px``.

    Raises
    ------
    ValueError
        If a string value ends with none of the accepted units.
    """
    if isinstance(value, (int, float)):
        return f"{value}px"
    has_known_unit = value[-3:] == "rem" or value[-2:] in ("em", "px")
    if not has_known_unit:
        raise ValueError("The font size must be expressed in rem, em, or px.")
    return value