Skip to content

Commit ccb7509

Browse files
authored
Fix type hints after adding Branca type checking (#2060)
* remove render() return types * fix TypeBounds * missing return statement * split TypeBounds in input and return types * deal with bounds from args to return * fix VegaLite typing * geojsondetail assert parent is geojson * bin_edges in choropleth * geojson/topojson in choropleth * colormap type in ColorLine * ruff check * black * fix circular import
1 parent 09c5905 commit ccb7509

8 files changed

+96
-45
lines changed

Diff for: folium/elements.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ class JSCSSMixin(Element):
1212
default_js: List[Tuple[str, str]] = []
1313
default_css: List[Tuple[str, str]] = []
1414

15-
def render(self, **kwargs) -> None:
15+
def render(self, **kwargs):
1616
figure = self.get_root()
1717
assert isinstance(
1818
figure, Figure

Diff for: folium/features.py

+50-28
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,24 @@
1212
import numpy as np
1313
import requests
1414
from branca.colormap import ColorMap, LinearColormap, StepColormap
15-
from branca.element import Element, Figure, Html, IFrame, JavascriptLink, MacroElement
15+
from branca.element import (
16+
Div,
17+
Element,
18+
Figure,
19+
Html,
20+
IFrame,
21+
JavascriptLink,
22+
MacroElement,
23+
)
1624
from branca.utilities import color_brewer
1725

1826
from folium.elements import JSCSSMixin
1927
from folium.folium import Map
2028
from folium.map import FeatureGroup, Icon, Layer, Marker, Popup, Tooltip
2129
from folium.template import Template
2230
from folium.utilities import (
31+
TypeBoundsReturn,
32+
TypeContainer,
2333
TypeJsonValue,
2434
TypeLine,
2535
TypePathOptions,
@@ -165,7 +175,7 @@ def __init__(
165175
self.top = _parse_size(top)
166176
self.position = position
167177

168-
def render(self, **kwargs) -> None:
178+
def render(self, **kwargs):
169179
"""Renders the HTML representation of the element."""
170180
super().render(**kwargs)
171181

@@ -284,9 +294,15 @@ def __init__(
284294
self.top = _parse_size(top)
285295
self.position = position
286296

287-
def render(self, **kwargs) -> None:
297+
def render(self, **kwargs):
288298
"""Renders the HTML representation of the element."""
289-
self._parent.html.add_child(
299+
parent = self._parent
300+
if not isinstance(parent, (Figure, Div, Popup)):
301+
raise TypeError(
302+
"VegaLite elements can only be added to a Figure, Div, or Popup"
303+
)
304+
305+
parent.html.add_child(
290306
Element(
291307
Template(
292308
"""
@@ -331,7 +347,7 @@ def render(self, **kwargs) -> None:
331347
embed_vegalite = embed_mapping.get(
332348
self.vegalite_major_version, self._embed_vegalite_v2
333349
)
334-
embed_vegalite(figure)
350+
embed_vegalite(figure=figure, parent=parent)
335351

336352
@property
337353
def vegalite_major_version(self) -> Optional[int]:
@@ -342,8 +358,8 @@ def vegalite_major_version(self) -> Optional[int]:
342358

343359
return int(schema.split("/")[-1].split(".")[0].lstrip("v"))
344360

345-
def _embed_vegalite_v5(self, figure: Figure) -> None:
346-
self._vega_embed()
361+
def _embed_vegalite_v5(self, figure: Figure, parent: TypeContainer) -> None:
362+
self._vega_embed(parent=parent)
347363

348364
figure.header.add_child(
349365
JavascriptLink("https://cdn.jsdelivr.net/npm//vega@5"), name="vega"
@@ -356,8 +372,8 @@ def _embed_vegalite_v5(self, figure: Figure) -> None:
356372
name="vega-embed",
357373
)
358374

359-
def _embed_vegalite_v4(self, figure: Figure) -> None:
360-
self._vega_embed()
375+
def _embed_vegalite_v4(self, figure: Figure, parent: TypeContainer) -> None:
376+
self._vega_embed(parent=parent)
361377

362378
figure.header.add_child(
363379
JavascriptLink("https://cdn.jsdelivr.net/npm//vega@5"), name="vega"
@@ -370,8 +386,8 @@ def _embed_vegalite_v4(self, figure: Figure) -> None:
370386
name="vega-embed",
371387
)
372388

373-
def _embed_vegalite_v3(self, figure: Figure) -> None:
374-
self._vega_embed()
389+
def _embed_vegalite_v3(self, figure: Figure, parent: TypeContainer) -> None:
390+
self._vega_embed(parent=parent)
375391

376392
figure.header.add_child(
377393
JavascriptLink("https://cdn.jsdelivr.net/npm/vega@4"), name="vega"
@@ -384,8 +400,8 @@ def _embed_vegalite_v3(self, figure: Figure) -> None:
384400
name="vega-embed",
385401
)
386402

387-
def _embed_vegalite_v2(self, figure: Figure) -> None:
388-
self._vega_embed()
403+
def _embed_vegalite_v2(self, figure: Figure, parent: TypeContainer) -> None:
404+
self._vega_embed(parent=parent)
389405

390406
figure.header.add_child(
391407
JavascriptLink("https://cdn.jsdelivr.net/npm/vega@3"), name="vega"
@@ -398,8 +414,8 @@ def _embed_vegalite_v2(self, figure: Figure) -> None:
398414
name="vega-embed",
399415
)
400416

401-
def _vega_embed(self) -> None:
402-
self._parent.script.add_child(
417+
def _vega_embed(self, parent: TypeContainer) -> None:
418+
parent.script.add_child(
403419
Element(
404420
Template(
405421
"""
@@ -412,8 +428,8 @@ def _vega_embed(self) -> None:
412428
name=self.get_name(),
413429
)
414430

415-
def _embed_vegalite_v1(self, figure: Figure) -> None:
416-
self._parent.script.add_child(
431+
def _embed_vegalite_v1(self, figure: Figure, parent: TypeContainer) -> None:
432+
parent.script.add_child(
417433
Element(
418434
Template(
419435
"""
@@ -436,19 +452,19 @@ def _embed_vegalite_v1(self, figure: Figure) -> None:
436452
figure.header.add_child(
437453
JavascriptLink("https://cdnjs.cloudflare.com/ajax/libs/vega/2.6.5/vega.js"),
438454
name="vega",
439-
) # noqa
455+
)
440456
figure.header.add_child(
441457
JavascriptLink(
442458
"https://cdnjs.cloudflare.com/ajax/libs/vega-lite/1.3.1/vega-lite.js"
443459
),
444460
name="vega-lite",
445-
) # noqa
461+
)
446462
figure.header.add_child(
447463
JavascriptLink(
448464
"https://cdnjs.cloudflare.com/ajax/libs/vega-embed/2.2.0/vega-embed.js"
449465
),
450466
name="vega-embed",
451-
) # noqa
467+
)
452468

453469

454470
class GeoJson(Layer):
@@ -820,7 +836,7 @@ def _get_self_bounds(self) -> List[List[Optional[float]]]:
820836
"""
821837
return get_bounds(self.data, lonlat=True)
822838

823-
def render(self, **kwargs) -> None:
839+
def render(self, **kwargs):
824840
self.parent_map = get_obj_in_upper_tree(self, Map)
825841
# Need at least one feature, otherwise style mapping fails
826842
if (self.style or self.highlight) and self.data["features"]:
@@ -1041,12 +1057,12 @@ def recursive_get(data, keys):
10411057
self.style_function(feature)
10421058
) # noqa
10431059

1044-
def render(self, **kwargs) -> None:
1060+
def render(self, **kwargs):
10451061
"""Renders the HTML representation of the element."""
10461062
self.style_data()
10471063
super().render(**kwargs)
10481064

1049-
def get_bounds(self) -> List[List[float]]:
1065+
def get_bounds(self) -> TypeBoundsReturn:
10501066
"""
10511067
Computes the bounds of the object itself (not including it's children)
10521068
in the form [[lat_min, lon_min], [lat_max, lon_max]]
@@ -1146,6 +1162,7 @@ def __init__(
11461162

11471163
def warn_for_geometry_collections(self) -> None:
11481164
"""Checks for GeoJson GeometryCollection features to warn user about incompatibility."""
1165+
assert isinstance(self._parent, GeoJson)
11491166
geom_collections = [
11501167
feature.get("properties") if feature.get("properties") is not None else key
11511168
for key, feature in enumerate(self._parent.data["features"])
@@ -1160,7 +1177,7 @@ def warn_for_geometry_collections(self) -> None:
11601177
UserWarning,
11611178
)
11621179

1163-
def render(self, **kwargs) -> None:
1180+
def render(self, **kwargs):
11641181
"""Renders the HTML representation of the element."""
11651182
figure = self.get_root()
11661183
if isinstance(self._parent, GeoJson):
@@ -1565,7 +1582,7 @@ def __init__(
15651582
color_range = color_brewer(fill_color, n=nb_bins)
15661583
self.color_scale = StepColormap(
15671584
color_range,
1568-
index=bin_edges,
1585+
index=list(bin_edges),
15691586
vmin=bins_min,
15701587
vmax=bins_max,
15711588
caption=legend_name,
@@ -1625,7 +1642,7 @@ def highlight_function(x):
16251642
return {"weight": line_weight + 2, "fillOpacity": fill_opacity + 0.2}
16261643

16271644
if topojson:
1628-
self.geojson = TopoJson(
1645+
self.geojson: Union[TopoJson, GeoJson] = TopoJson(
16291646
geo_data,
16301647
topojson,
16311648
style_function=style_function,
@@ -1657,7 +1674,7 @@ def _get_by_key(cls, obj: Union[dict, list], key: str) -> Union[float, str, None
16571674
else:
16581675
return value
16591676

1660-
def render(self, **kwargs) -> None:
1677+
def render(self, **kwargs):
16611678
"""Render the GeoJson/TopoJson and color scale objects."""
16621679
if self.color_scale:
16631680
# ColorMap needs Map as its parent
@@ -1963,8 +1980,13 @@ def __init__(
19631980
vmin=min(colors),
19641981
vmax=max(colors),
19651982
).to_step(nb_steps)
1966-
else:
1983+
elif isinstance(colormap, StepColormap):
19671984
cm = colormap
1985+
else:
1986+
raise TypeError(
1987+
f"Unexpected type for argument `colormap`: {type(colormap)}"
1988+
)
1989+
19681990
out: Dict[str, List[List[List[float]]]] = {}
19691991
for (lat1, lng1), (lat2, lng2), color in zip(coords[:-1], coords[1:], colors):
19701992
out.setdefault(cm(color), []).append([[lat1, lng1], [lat2, lng2]])

Diff for: folium/folium.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ def _repr_png_(self) -> Optional[bytes]:
377377
return None
378378
return self._to_png()
379379

380-
def render(self, **kwargs) -> None:
380+
def render(self, **kwargs):
381381
"""Renders the HTML representation of the element."""
382382
figure = self.get_root()
383383
assert isinstance(

Diff for: folium/map.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import warnings
77
from collections import OrderedDict
8-
from typing import TYPE_CHECKING, List, Optional, Sequence, Union
8+
from typing import TYPE_CHECKING, Optional, Sequence, Union, cast
99

1010
from branca.element import Element, Figure, Html, MacroElement
1111

@@ -14,6 +14,7 @@
1414
from folium.utilities import (
1515
JsCode,
1616
TypeBounds,
17+
TypeBoundsReturn,
1718
TypeJsonValue,
1819
escape_backticks,
1920
parse_options,
@@ -221,7 +222,7 @@ def reset(self) -> None:
221222
self.base_layers = OrderedDict()
222223
self.overlays = OrderedDict()
223224

224-
def render(self, **kwargs) -> None:
225+
def render(self, **kwargs):
225226
"""Renders the HTML representation of the element."""
226227
self.reset()
227228
for item in self._parent._children.values():
@@ -396,15 +397,15 @@ def __init__(
396397
tooltip if isinstance(tooltip, Tooltip) else Tooltip(str(tooltip))
397398
)
398399

399-
def _get_self_bounds(self) -> List[List[float]]:
400+
def _get_self_bounds(self) -> TypeBoundsReturn:
400401
"""Computes the bounds of the object itself.
401402
402403
Because a marker has only single coordinates, we repeat them.
403404
"""
404405
assert self.location is not None
405-
return [self.location, self.location]
406+
return cast(TypeBoundsReturn, [self.location, self.location])
406407

407-
def render(self) -> None:
408+
def render(self):
408409
if self.location is None:
409410
raise ValueError(
410411
f"{self._name} location must be assigned when added directly to map."
@@ -492,7 +493,7 @@ def __init__(
492493
**kwargs,
493494
)
494495

495-
def render(self, **kwargs) -> None:
496+
def render(self, **kwargs):
496497
"""Renders the HTML representation of the element."""
497498
for name, child in self._children.items():
498499
child.render(**kwargs)

Diff for: folium/plugins/overlapping_marker_spiderfier.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def add_to(
9292
) -> Element:
9393
self._parent = parent
9494
self.markers = self._get_all_markers(parent)
95-
super().add_to(parent, name=name, index=index)
95+
return super().add_to(parent, name=name, index=index)
9696

9797
def _get_all_markers(self, element: Element) -> list:
9898
markers = []

Diff for: folium/raster_layers.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@
1212
from folium.template import Template
1313
from folium.utilities import (
1414
TypeBounds,
15+
TypeBoundsReturn,
1516
TypeJsonValue,
1617
image_to_url,
1718
mercator_transform,
19+
normalize_bounds_type,
1820
parse_options,
1921
remove_empty,
2022
)
@@ -246,7 +248,7 @@ class ImageOverlay(Layer):
246248
* If string, it will be written directly in the output file.
247249
* If file, it's content will be converted as embedded in the output file.
248250
* If array-like, it will be converted to PNG base64 string and embedded in the output.
249-
bounds: list
251+
bounds: list/tuple of list/tuple of float
250252
Image bounds on the map in the form
251253
[[lat_min, lon_min], [lat_max, lon_max]]
252254
opacity: float, default Leaflet's default (1.0)
@@ -319,7 +321,7 @@ def __init__(
319321

320322
self.url = image_to_url(image, origin=origin, colormap=colormap)
321323

322-
def render(self, **kwargs) -> None:
324+
def render(self, **kwargs):
323325
super().render()
324326

325327
figure = self.get_root()
@@ -344,13 +346,13 @@ def render(self, **kwargs) -> None:
344346
Element(pixelated), name="leaflet-image-layer"
345347
) # noqa
346348

347-
def _get_self_bounds(self) -> TypeBounds:
349+
def _get_self_bounds(self) -> TypeBoundsReturn:
348350
"""
349351
Computes the bounds of the object itself (not including it's children)
350352
in the form [[lat_min, lon_min], [lat_max, lon_max]].
351353
352354
"""
353-
return self.bounds
355+
return normalize_bounds_type(self.bounds)
354356

355357

356358
class VideoOverlay(Layer):
@@ -361,7 +363,7 @@ class VideoOverlay(Layer):
361363
----------
362364
video_url: str
363365
URL of the video
364-
bounds: list
366+
bounds: list/tuple of list/tuple of float
365367
Video bounds on the map in the form
366368
[[lat_min, lon_min], [lat_max, lon_max]]
367369
autoplay: bool, default True
@@ -411,10 +413,10 @@ def __init__(
411413
self.bounds = bounds
412414
self.options = remove_empty(autoplay=autoplay, loop=loop, **kwargs)
413415

414-
def _get_self_bounds(self) -> TypeBounds:
416+
def _get_self_bounds(self) -> TypeBoundsReturn:
415417
"""
416418
Computes the bounds of the object itself (not including it's children)
417419
in the form [[lat_min, lon_min], [lat_max, lon_max]]
418420
419421
"""
420-
return self.bounds
422+
return normalize_bounds_type(self.bounds)

0 commit comments

Comments
 (0)