Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
heykeetae authored May 28, 2018
1 parent f097a20 commit 4cae04f
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 43 deletions.
2 changes: 0 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@

from parameter import *
from trainer import Trainer
from qgan_trainer import Trainer as qgan_trainer
# from tester import Tester
from data_loader import Data_Loader
from torch.backends import cudnn
Expand All @@ -18,7 +17,6 @@ def main(config):

# Create directories if not exist
make_folder(config.model_save_path, config.version)
make_folder(config.result_path, config.version)
make_folder(config.sample_path, config.version)
make_folder(config.log_path, config.version)
make_folder(config.attn_path, config.version)
Expand Down
28 changes: 2 additions & 26 deletions parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def get_parameters():
parser.add_argument('--g_conv_dim', type=int, default=64)
parser.add_argument('--d_conv_dim', type=int, default=64)
parser.add_argument('--lambda_gp', type=float, default=10)
parser.add_argument('--version', type=str, default='sagan_1')

# Training setting
parser.add_argument('--total_step', type=int, default=1000000, help='how many times to update the generator')
Expand All @@ -31,33 +32,12 @@ def get_parameters():
# using pretrained
parser.add_argument('--pretrained_model', type=int, default=None)

# gating net
parser.add_argument('--gum_orig', type=float, default=1) # gum start temperature
parser.add_argument('--gum_temp', type=float, default=1)
parser.add_argument('--min_temp', type=float, default=0.01)
parser.add_argument('--gum_temp_decay', type=float, default=0.0001)
parser.add_argument('--step_anneal', type=int, default=1) # epoch to apply decaying
parser.add_argument('--start_anneal', type=int, default=0) # epoch to start annealing


# Test setting
parser.add_argument('--test_size', type=int, default=64)
parser.add_argument('--test_model', type=str, default='50000_G.pth')
parser.add_argument('--result_path', type=str, default='./results')
parser.add_argument('--version', type=str, default='Gum')
parser.add_argument('--nrow', type=int, default=8)
parser.add_argument('--ncol', type=int, default=8)

# Misc
parser.add_argument('--train', type=str2bool, default=True)
parser.add_argument('--parallel', type=str2bool, default=False)
parser.add_argument('--dataset', type=str, default='cifar', choices=['lsun', 'celeb', ])
parser.add_argument('--dataset', type=str, default='cifar', choices=['lsun', 'celeb'])
parser.add_argument('--use_tensorboard', type=str2bool, default=False)

# Load balance
parser.add_argument('--load_balance_on', type=str2bool, default=False)
parser.add_argument('--load_weight', type=float, default=1.0) # for 2, for 5 1000, for 4500

# Path
parser.add_argument('--image_path', type=str, default='./data')
parser.add_argument('--log_path', type=str, default='./logs')
Expand All @@ -70,9 +50,5 @@ def get_parameters():
parser.add_argument('--sample_step', type=int, default=100)
parser.add_argument('--model_save_step', type=float, default=1.0)

# claculating quantitative measures
parser.add_argument('--score_epoch', type=int, default=3) # = 5 epochs
parser.add_argument('--score_start', type=int, default=3) # start at 5 (default)


return parser.parse_args()
17 changes: 2 additions & 15 deletions trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,30 +40,17 @@ def __init__(self, data_loader, config):
self.beta1 = config.beta1
self.beta2 = config.beta2
self.pretrained_model = config.pretrained_model
self.gum_orig = config.gum_orig
self.gum_temp = config.gum_temp
self.min_temp = config.min_temp
self.gum_temp_decay = config.gum_temp_decay
self.step_anneal = config.step_anneal
self.start_anneal = config.start_anneal
self.test_model = config.test_model
self.result_path = config.result_path
self.version = config.version
self.nrow = config.nrow
self.ncol = config.ncol

self.dataset = config.dataset
self.use_tensorboard = config.use_tensorboard
self.load_balance_on = config.load_balance_on
self.load_weight = config.load_weight
self.image_path = config.image_path
self.log_path = config.log_path
self.model_save_path = config.model_save_path
self.sample_path = config.sample_path
self.log_step = config.log_step
self.sample_step = config.sample_step
self.model_save_step = config.model_save_step
self.score_epoch = config.score_epoch
self.score_start = config.score_start
self.version = config.version

# Path
self.log_path = os.path.join(config.log_path, self.version)
Expand Down

0 comments on commit 4cae04f

Please sign in to comment.