forked from UMich-CURLY/LatentBKI

Commit bdb5ea8 (1 parent: eac7389): add PCA experiment and generate_dataset.py

This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Showing 6 changed files with 384 additions and 1 deletion.
dataset.py (new file, +91 lines)
import os
import cv2
import numpy as np
import warnings
from torch.utils.data import Dataset
import yaml

with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=DeprecationWarning)

# The config is loaded at import time, so ipca.yaml must sit in the working directory.
config_file = os.path.join(os.getcwd(), "ipca.yaml")
with open(config_file, 'r') as f:
    config = yaml.safe_load(f)


class HabitatDataset(Dataset):
    """Dataset class for mp3d habitat."""

    def __init__(self, dataset_dir, data_split='train'):
        """Read in the data from disk."""
        super().__init__()

        if data_split not in ["train", "val", "all"]:
            raise ValueError("Partition {} does not exist".format(data_split))

        self.split = data_split
        if data_split == "all":
            self.seq_names = (
                config['split']['train'] + config['split']['val']
            )
        else:
            self.seq_names = config['split'][self.split]

        self.seq_dirs = [os.path.join(dataset_dir, sequence) for sequence in self.seq_names]
        self.seqs = [HabitatFrameDataset(seq_dir, self.seq_names[i]) for i, seq_dir in enumerate(self.seq_dirs)]

    def __len__(self):
        """Return size of dataset."""
        return len(self.seqs)

    def __getitem__(self, idx):
        """Return sample at index `idx` of dataset."""
        return self.seqs[idx]


class HabitatFrameDataset(Dataset):
    """Dataset class for sequences of scenes."""

    def __init__(self, seq_dir, seq_name):
        """Read data from disk for a single scene."""
        super().__init__()
        # Define paths to the dir
        self.seq_name = seq_name
        self.dir = seq_dir
        self.rgb_dir = os.path.join(seq_dir, 'rgb')
        self.depth_dir = os.path.join(seq_dir, 'depth')
        self.semantic_dir = os.path.join(seq_dir, 'semantic')
        self.pose_file = os.path.join(seq_dir, 'poses.txt')
        # Sorted per-frame file lists, so RGB, depth, and semantics stay aligned by index
        self.rgb_files = [os.path.join(self.rgb_dir, file) for file in sorted(os.listdir(self.rgb_dir))]
        self.depth_files = [os.path.join(self.depth_dir, file) for file in sorted(os.listdir(self.depth_dir))]
        self.semantic_files = [os.path.join(self.semantic_dir, file)
                               for file in sorted(os.listdir(self.semantic_dir))]
        self.poses = list(np.loadtxt(self.pose_file))  # one pose row per frame

    def __len__(self):
        """Return the number of frames in the scene."""
        return len(self.rgb_files)

    def __getitem__(self, idx):
        """Return data for a single frame."""
        # Load the RGB image (OpenCV reads BGR, so convert)
        bgr = cv2.imread(self.rgb_files[idx])
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        # Load the depth
        with open(self.depth_files[idx], "rb") as f:
            depth = np.load(f)
        # Load the ground-truth semantic map
        semantic = np.load(self.semantic_files[idx])
        # Get the pose
        pose = self.poses[idx]
        frame = {'rgb': rgb, 'depth': depth, 'pose': pose, 'semantic': semantic}
        return frame
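For reference, a minimal usage sketch (not part of the commit; the dataset path is illustrative), assuming ipca.yaml is present in the working directory as the module expects:

from dataset import HabitatDataset

sequences = HabitatDataset(dataset_dir='/path/to/vlmaps_dataset', data_split='train')
seq = sequences[0]      # a HabitatFrameDataset for one scene
frame = seq[0]          # dict with 'rgb', 'depth', 'pose', 'semantic'
print(seq.seq_name, frame['rgb'].shape, frame['semantic'].shape)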
ipca.yaml (new file, +60 lines)
mp3d_data_dir: '/workspace/sda1/vlmaps_data_dir/vlmaps_dataset'
target_dimension: '64'
image_height: 720
image_width: 1080
save_dir: '/workspace/sda1/LatentBKI_cleanup/PCAExperiment/predictions'
split:
  train:
    - jh4fc5c5qoQ_1
    # - JmbYfDe2QKZ_1
    # - JmbYfDe2QKZ_2
    # - mJXqzFtmKg4_1
    # - ur6pFq6Qu1A_1
    # - UwV83HsGsw3_1
    # - Vt2qJdWjCF2_1
    # - YmJkqBEsHnH_1
  val:
    - gTV8FGcVJC9_1
    - 5LpN3gDmAk7_1
labels:
  - void
  - wall
  - floor
  - chair
  - door
  - table
  - picture
  - cabinet
  - cushion
  - window
  - sofa
  - bed
  - curtain
  - chest_of_drawers
  - plant
  - sink
  - stairs
  - ceiling
  - toilet
  - stool
  - towel
  - mirror
  - tv_monitor
  - shower
  - column
  - bathtub
  - counter
  - fireplace
  - lighting
  - beam
  - railing
  - shelving
  - blinds
  - gym_equipment
  - seating
  - board_panel
  - furniture
  - appliances
  - clothes
  - objects
  - misc
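Note that target_dimension is quoted, so it loads as a string; the scripts below use it directly as a directory name and cast it with int(...) only when constructing the PCA. A short sketch of that dual use (values taken from this file):

import os
import yaml

with open("ipca.yaml") as f:
    config = yaml.safe_load(f)

target_dimension = config["target_dimension"]                 # '64', a string
run_dir = os.path.join(config["save_dir"], target_dimension)  # .../predictions/64
n_components = int(target_dimension)                          # 64, an int, for the PCA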
Evaluation script (new file, +38 lines; computes mIoU and pixel accuracy from the saved predictions)
import os
import yaml
from ipca_utils import get_miou, get_accuracy


def main():
    config_file = os.path.join(os.getcwd(), "ipca.yaml")
    with open(config_file, "r") as stream:
        try:
            config = yaml.safe_load(stream)
            target_dimension = config["target_dimension"]
            val_seqs = config["split"]["val"]
            labels = config["labels"]
            data_dir = config["mp3d_data_dir"]
            save_dir = config["save_dir"]
        except yaml.YAMLError as exc:
            print(exc)
            return  # without a config the names below are undefined, so bail out
    num_classes = len(labels)
    result_file_path = os.path.join(os.getcwd(), "predictions", target_dimension, 'result.txt')
    mean_iou = get_miou(num_classes, val_seqs, data_dir, save_dir, target_dimension)
    accuracy = get_accuracy(val_seqs, data_dir, save_dir, target_dimension)
    print(f"mean iou: {mean_iou}\n")
    print(f"accuracy: {accuracy}\n")
    with open(result_file_path, mode='a') as file:
        file.write(f"mean iou: {mean_iou}\n")
        file.write(f"accuracy: {accuracy}\n")


if __name__ == "__main__":
    main()
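One thing to watch: result_file_path is resolved relative to os.getcwd() rather than the configured save_dir, so the metrics land next to the predictions only when the script is run from the directory containing the predictions folder (and that folder must already exist, since open(..., mode='a') does not create parent directories). A sketch of the paths involved, using the values from ipca.yaml above:

import os

save_dir = '/workspace/sda1/LatentBKI_cleanup/PCAExperiment/predictions'
target_dimension = '64'

pca_ckpt = os.path.join(save_dir, target_dimension, 'ipca_0.pkl')        # PCA state per training sequence
frame_npy = os.path.join(save_dir, target_dimension,
                         'gTV8FGcVJC9_1_pred', '000000.npy')             # per-frame prediction
result_txt = os.path.join(os.getcwd(), 'predictions',
                          target_dimension, 'result.txt')                # metrics, cwd-relative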
Main experiment script (new file, +109 lines; fits the incremental PCA on LSeg features and writes per-frame predictions)
import os
import sys
sys.path.append("..")

import clip
import torch
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
import pickle
import yaml

from PCAonGPU.gpu_pca import IncrementalPCAonGPU
from ipca_utils import get_sim_mat, get_flattened_data
from dataset import HabitatDataset
from Models.Lseg.lseg_utils import init_lseg, get_lseg_feat


def main():
    config_file = os.path.join(os.getcwd(), "ipca.yaml")
    with open(config_file, "r") as stream:
        try:
            config = yaml.safe_load(stream)
            target_dimension = config["target_dimension"]
            labels = config["labels"]
            h = config["image_height"]
            w = config["image_width"]
            save_dir = config["save_dir"]
            data_dir = config["mp3d_data_dir"]
        except yaml.YAMLError as exc:
            print(exc)
            return  # config is unusable, so stop here

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, _ = clip.load("ViT-B/32", device)

    save_dir = os.path.join(save_dir, target_dimension)
    os.makedirs(save_dir, exist_ok=True)

    ipca = IncrementalPCAonGPU(n_components=int(target_dimension))

    train_sequences = HabitatDataset(dataset_dir=data_dir, data_split="train")
    val_sequences = HabitatDataset(dataset_dir=data_dir, data_split="val")
    # init lseg model
    (lseg_model, lseg_transform, crop_size,
     base_size, norm_mean, norm_std, clip_feat_dim) = init_lseg(device)

    # train: incrementally fit the PCA on per-pixel LSeg features
    for seq_i, sequence in enumerate(tqdm(train_sequences, desc="Training")):
        # if a PCA checkpoint for this sequence exists, load it and skip refitting
        pca_save_path = os.path.join(save_dir, f'ipca_{seq_i}.pkl')
        if os.path.exists(pca_save_path):
            print(pca_save_path, "exists, continue")
            with open(pca_save_path, 'rb') as file:
                ipca = pickle.load(file)
            print(f'load ipca_{seq_i}.pkl')
            continue

        for frame_i, frame in enumerate(tqdm(sequence, desc=sequence.seq_name)):
            pix_feats = get_lseg_feat(
                lseg_model, frame['rgb'], labels,
                lseg_transform, device, crop_size,
                base_size, norm_mean, norm_std, vis=False
            )
            flattened_data = get_flattened_data(pix_feats, device)
            ipca.partial_fit(flattened_data)

        with open(pca_save_path, 'wb') as file:
            pickle.dump(ipca, file)
        print(f'ipca_{seq_i}.pkl saved.')

    # encode the label set once with CLIP's text encoder
    label_token = clip.tokenize(labels).to(device)
    with torch.no_grad():
        label_features = model.encode_text(label_token)

    # validation: compress, back-project, and classify by cosine similarity
    for seq_i, sequence in enumerate(tqdm(val_sequences, desc="Validation")):

        for frame_i, frame in enumerate(tqdm(sequence, desc=sequence.seq_name)):
            pred_dir = os.path.join(save_dir, f'{sequence.seq_name}_pred')
            os.makedirs(pred_dir, exist_ok=True)

            pred_path = os.path.join(pred_dir, f'{frame_i:06d}.npy')
            if os.path.exists(pred_path):
                continue

            pix_feats = get_lseg_feat(
                lseg_model, frame['rgb'], labels,
                lseg_transform, device, crop_size,
                base_size, norm_mean, norm_std, vis=False
            )
            flattened_data = get_flattened_data(pix_feats, device)

            pca_data = ipca.transform(flattened_data)

            # back project to 512 dimensions
            bp_data = ipca.inverse_transform(pca_data)

            similarity_matrix = get_sim_mat(1000, bp_data, label_features)

            prediction_probs = F.softmax(similarity_matrix, dim=1)
            predictions = torch.argmax(prediction_probs, dim=1)
            predictions = predictions.reshape(h, w)

            # save predictions
            predictions = predictions.to('cpu').numpy()
            np.save(pred_path, predictions)


if __name__ == "__main__":
    main()
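The core of the experiment is a reduce-then-reconstruct round trip: incrementally fit a PCA on 512-dimensional LSeg features, compress to target_dimension, back-project to 512, and classify the reconstruction against CLIP text embeddings. A minimal CPU sketch of the same round trip using scikit-learn's IncrementalPCA (assuming IncrementalPCAonGPU mirrors its partial_fit / transform / inverse_transform API; the toy data is random, not real features):

import numpy as np
from sklearn.decomposition import IncrementalPCA

# Toy stand-in for per-pixel LSeg features: N pixels x 512 dims.
rng = np.random.default_rng(0)
features = rng.standard_normal((5000, 512)).astype(np.float32)

ipca = IncrementalPCA(n_components=64)
for batch in np.array_split(features, 5):   # streaming fit, like per-frame partial_fit
    ipca.partial_fit(batch)

compressed = ipca.transform(features)               # (5000, 64)
reconstructed = ipca.inverse_transform(compressed)  # (5000, 512), lossy

err = np.linalg.norm(features - reconstructed) / np.linalg.norm(features)
print(f"relative reconstruction error: {err:.3f}")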
ipca_utils.py (new file, +85 lines)
import os
import torch
import numpy as np
import torch.nn.functional as F


def get_sim_mat(batch_size, bp_data, label_features):
    """Cosine similarity between each feature row and each label embedding, in batches."""
    num_batches = (bp_data.size(0) + batch_size - 1) // batch_size
    similarity_matrices = []
    for i in range(num_batches):
        start_idx = i * batch_size
        end_idx = min(start_idx + batch_size, bp_data.size(0))
        batch_bp_data = bp_data[start_idx:end_idx]
        # (B, 1, D) against (1, C, D) broadcasts to a (B, C) similarity block
        similarity_matrix = F.cosine_similarity(batch_bp_data.unsqueeze(1), label_features.unsqueeze(0), dim=2)
        similarity_matrices.append(similarity_matrix)

    similarity_matrix = torch.cat(similarity_matrices, dim=0)
    return similarity_matrix


def get_flattened_data(pix_feats, device):
    """Flatten (1, D, H, W) pixel features into an (H*W, D) tensor on `device`."""
    data = pix_feats
    data = data.squeeze(0)
    flattened_data = data.reshape(data.shape[0], -1)
    flattened_data = flattened_data.T
    flattened_data = torch.as_tensor(flattened_data).to(device)
    return flattened_data


def get_miou(num_classes, val_seqs, data_dir, save_dir, target_dimension):
    """Mean IoU over classes, accumulated across all validation frames."""
    iou_list = []
    for cls in range(num_classes):
        intersection = 0
        union = 0

        for seq in val_seqs:

            gt_dir = os.path.join(data_dir, f'{seq}/semantic')
            gt_files = sorted([f for f in os.listdir(gt_dir) if f.endswith('.npy')])

            for i, gt_file in enumerate(gt_files):

                gt_pred = np.load(os.path.join(gt_dir, gt_file))

                pca_pred = np.load(os.path.join(save_dir, target_dimension,
                                                f'{seq}_pred', gt_file))
                gt_mask = (gt_pred == cls)
                pred_mask = (pca_pred == cls)

                # Calculate intersection and union
                intersection += np.logical_and(gt_mask, pred_mask).sum()
                union += np.logical_or(gt_mask, pred_mask).sum()

        iou = intersection / (union + 1e-6)

        iou_list.append(iou)
        print("class: ", cls)

    # Calculate mean IoU
    return np.mean(iou_list)


def get_accuracy(val_seqs, data_dir, save_dir, target_dimension):
    """Overall pixel accuracy across all validation frames."""
    correct_predictions = 0
    total_predictions = 0
    for seq in val_seqs:

        gt_dir = os.path.join(data_dir, f'{seq}/semantic')
        gt_files = sorted([f for f in os.listdir(gt_dir) if f.endswith('.npy')])

        for i, gt_file in enumerate(gt_files):

            gt_pred = np.load(os.path.join(gt_dir, gt_file))

            pca_pred = np.load(os.path.join(save_dir, target_dimension,
                                            f'{seq}_pred', gt_file))

            correct_predictions += (gt_pred == pca_pred).sum()
            total_predictions += gt_pred.size

    return correct_predictions / total_predictions
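As a quick shape check (toy tensors, not real features): for N back-projected pixel features and C labels, get_sim_mat returns an (N, C) matrix:

import torch
from ipca_utils import get_sim_mat

bp_data = torch.randn(2500, 512)        # N = 2500 reconstructed pixel features
label_features = torch.randn(40, 512)   # C = 40 label embeddings
sim = get_sim_mat(1000, bp_data, label_features)
print(sim.shape)                        # torch.Size([2500, 40])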