Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12,692 changes: 12,692 additions & 0 deletions experiments.ipynb

Large diffs are not rendered by default.

44 changes: 44 additions & 0 deletions heuristics/model/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,49 @@
from typing import Dict, Optional

import pandas as pd
from avito_ds_swat_utils.storages.dwh import VerticaPandasConnection
from avito_ds_swat_utils.utils.image_manager import ImageManager
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

from .settings import (
IMAGES_DIR,
LABELS_PATH,
MAX_IMG_SIZE_STR,
)


def get_pd_dataset(cache_path=LABELS_PATH):
    """Return the labels table as a pandas DataFrame, using a CSV cache.

    When ``cache_path`` points to an existing file, that CSV is read and
    returned as-is. Otherwise the table is queried through
    ``VerticaPandasConnection``, rows with a missing ``image_id_ext`` are
    dropped, the column is cast to ``int``, the index is reset, and (if
    ``cache_path`` is not ``None``) the result is written to ``cache_path``.

    Args:
        cache_path: Path of the CSV cache; ``None`` disables both reading
            from and writing to the cache.

    Returns:
        The labels DataFrame.
    """
    caching_enabled = cache_path is not None

    # Cache hit: nothing to query or clean up.
    if caching_enabled and os.path.exists(cache_path):
        return pd.read_csv(cache_path)

    with VerticaPandasConnection() as dwh:
        labels = dwh.sql_to_pd('select * from dsswat.datasets_course_room_type')

    labels = labels.dropna(subset=['image_id_ext'])
    labels['image_id_ext'] = labels['image_id_ext'].astype(int)
    # NOTE: `drop` is not set, so pandas keeps the previous index as a column.
    labels.reset_index(inplace=True)

    if caching_enabled:
        labels.to_csv(cache_path, index=False)

    return labels


def load_images(image_ids, save_dir=IMAGES_DIR, img_size=MAX_IMG_SIZE_STR, image_save_shards=False):
    """Run ``ImageManager.process_images`` for ``image_ids`` with disk saving on.

    Args:
        image_ids: Forwarded as ``image_id_list``.
        save_dir: Forwarded as ``image_root_dir``.
        img_size: Forwarded as ``size``.
        image_save_shards: Forwarded unchanged under the same name.

    Returns:
        None; the result of ``process_images`` is not returned.
    """
    # Fixed request options (schema/version/private) plus caller-supplied ones.
    request = {
        'image_id_list': image_ids,
        'schema': 'item',
        'size': img_size,
        'version': 1,
        'private': True,
        'image_save_to_disk': True,
        'image_root_dir': save_dir,
        'image_save_shards': image_save_shards,
    }
    ImageManager().process_images(**request)


class TorchDataset(Dataset):
image_id_column: str = 'image_id_ext'
Expand All @@ -20,6 +59,7 @@ def __init__(
image_dir: str = None,
transformer: Optional[transforms.transforms.Compose] = None,
normalize_sample_weights: bool = True,
custom_transform : Optional[transforms.transforms.Compose] = None,
):
super(TorchDataset, self).__init__()
self.image_dir = image_dir
Expand All @@ -30,6 +70,7 @@ def __init__(
# self.df = self.df[self.df['result'] != -1]
self.transformer = transformer
self.image_cache = {}
self.custom_transform = custom_transform

def img_path_from_id(self, image_id):
path = os.path.join(self.image_dir, f'{image_id}.{self.img_extension}')
Expand Down Expand Up @@ -67,6 +108,9 @@ def __getitem__(self, idx):

self.image_cache[idx] = img

if self.custom_transform:
img = self.custom_transform(img)

# item = {
# 'img': img,
# 'label': item_series['result'],
Expand Down