|
21 | 21 | from dataclasses import dataclass
|
22 | 22 | from datetime import date, datetime, time
|
23 | 23 | from functools import cached_property, singledispatch
|
24 |
| -from typing import Annotated, Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union |
| 24 | +from typing import Annotated, Any, Callable, Dict, Generic, List, Optional, Tuple, TypeVar, Union |
25 | 25 | from urllib.parse import quote_plus
|
26 | 26 |
|
27 | 27 | from pydantic import (
|
@@ -272,6 +272,60 @@ def assign_fresh_partition_spec_ids(spec: PartitionSpec, old_schema: Schema, fre
|
T = TypeVar("T")


class PartitionMap(Generic[T]):
    """Two-level mapping from ``(spec_id, partition struct)`` to a value.

    Entries are kept in one inner dict per partition spec id, so the same
    partition struct may be stored independently under different spec ids.
    ``None`` is accepted as a struct key.
    """

    # Spec id -> partition spec; ``put`` only accepts spec ids present here.
    _specs: dict[int, PartitionSpec]
    # Spec id -> (partition struct -> stored value).
    _partition_maps: dict[int, dict[Record | None, T]]

    def __init__(self, specs: dict[int, PartitionSpec]):
        """Create an empty map.

        Args:
            specs: Partition specs keyed by spec id. The dict is stored by
                reference, not copied.
        """
        self._specs = specs
        self._partition_maps = {}

    def __len__(self) -> int:
        """Return the total number of entries across all spec ids."""
        # Count the inner dicts directly rather than materializing every value.
        return sum(len(partition_map) for partition_map in self._partition_maps.values())

    def is_empty(self) -> bool:
        """Return True when the map holds no entries."""
        return len(self) == 0

    def contains_key(self, spec_id: int, struct: Record | None) -> bool:
        """Return True when an entry exists for ``(spec_id, struct)``."""
        return struct in self._partition_maps.get(spec_id, {})

    def contains_value(self, value: T) -> bool:
        """Return True when any entry's value equals ``value``."""
        return any(value in partition_map.values() for partition_map in self._partition_maps.values())

    def get(self, spec_id: int, struct: Record | None) -> Optional[T]:
        """Return the value stored for ``(spec_id, struct)``, or None when absent.

        A stored value that happens to be falsy (``0``, ``""``, an empty
        collection, ...) is returned as-is; only a missing entry yields None.
        """
        partition_map = self._partition_maps.get(spec_id)
        if partition_map is None:
            return None
        # Plain dict lookup: a truthiness test here would hide falsy values.
        return partition_map.get(struct)

    def put(self, spec_id: int, struct: Record | None, value: T) -> None:
        """Store ``value`` under ``(spec_id, struct)``, replacing any existing entry.

        The call is silently ignored when ``spec_id`` is not a key of the
        specs passed to the constructor.
        """
        # Membership test, so a registered spec is accepted whatever its truthiness.
        if spec_id not in self._specs:
            return
        self._partition_maps.setdefault(spec_id, {})[struct] = value

    def compute_if_absent(self, spec_id: int, struct: Record | None, value_factory: Callable[[], T]) -> T:
        """Return the value for ``(spec_id, struct)``, creating it on first access.

        ``value_factory`` is called only when no entry exists yet; its result
        is stored and returned.

        Note: unlike ``put``, this method does not check ``spec_id`` against
        the known specs.
        """
        partition_map = self._partition_maps.setdefault(spec_id, {})
        if struct in partition_map:
            return partition_map[struct]

        value = value_factory()
        partition_map[struct] = value
        return value

    def values(self) -> list[T]:
        """Return all stored values as a new list, grouped by spec id in insertion order."""
        return [value for partition_map in self._partition_maps.values() for value in partition_map.values()]
| 327 | + |
| 328 | + |
275 | 329 | class PartitionSpecVisitor(Generic[T], ABC):
|
276 | 330 | @abstractmethod
|
277 | 331 | def identity(self, field_id: int, source_name: str, source_id: int) -> T:
|
|
0 commit comments