Skip to content

Commit

Permalink
formatting and info
Browse files Browse the repository at this point in the history
  • Loading branch information
rromb committed Jul 26, 2022
1 parent 7cc4a39 commit 513f009
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 8 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,12 @@ python scripts/train_searcher.py

Retrieval based text-guided sampling with visual nearest neighbors can be started via
```
python scripts/knn2img.py --prompt "a happy bear reading a newspaper" --use_neighbors --knn <number_of_neighbors>
python scripts/knn2img.py --prompt "a happy pineapple" --use_neighbors --knn <number_of_neighbors>
```
Note that the maximum supported number of neighbors is 20.
The database can be changed via the cmd parameter ``--database`` which can be `[openimages, artbench-art_nouveau, artbench-baroque, artbench-expressionism, artbench-impressionism, artbench-post_impressionism, artbench-realism, artbench-renaissance, artbench-romanticism, artbench-surrealism, artbench-ukiyo_e]`.
For using `--database openimages`, the above script (`scripts/train_searcher.py`) must be executed before.

Due to their relatively small size, the artbench datasetbases are best suited for creating more abstract concepts and do not work well for detailed text control.


#### Coming Soon
Expand Down
13 changes: 7 additions & 6 deletions ldm/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def ismap(x):


def isimage(x):
if not isinstance(x,torch.Tensor):
if not isinstance(x, torch.Tensor):
return False
return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)

Expand All @@ -71,7 +71,7 @@ def mean_flat(tensor):
def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters())
if verbose:
print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
return total_params


Expand All @@ -92,20 +92,21 @@ def get_obj_from_str(string, reload=False):
importlib.reload(module_imp)
return getattr(importlib.import_module(module, package=None), cls)


def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
# create dummy dataset instance

# run prefetching
if idx_to_fn:
res = func(data,worker_id=idx)
res = func(data, worker_id=idx)
else:
res = func(data)
Q.put([idx, res])
Q.put("Done")


def parallel_data_prefetch(
func: callable, data, n_proc, target_data_type="ndarray",cpu_intensive=True,use_worker_id=False
func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
):
# if target_data_type not in ["ndarray", "list"]:
# raise ValueError(
Expand Down Expand Up @@ -149,7 +150,7 @@ def parallel_data_prefetch(
arguments = [
[func, Q, part, i, use_worker_id]
for i, part in enumerate(
[data[i : i + step] for i in range(0, len(data), step)]
[data[i: i + step] for i in range(0, len(data), step)]
)
]
processes = []
Expand Down Expand Up @@ -199,4 +200,4 @@ def parallel_data_prefetch(
out.extend(r)
return out
else:
return gather_res
return gather_res

0 comments on commit 513f009

Please sign in to comment.