Complete lab2
This commit is contained in:
parent
976b1adc30
commit
65d6d65b0f
1
.gitignore
vendored
1
.gitignore
vendored
@ -1 +1,2 @@
|
|||||||
venv
|
venv
|
||||||
|
__pycache__
|
445
Lab2/examples/Lab31-mnist.ipynb
Normal file
445
Lab2/examples/Lab31-mnist.ipynb
Normal file
File diff suppressed because one or more lines are too long
1099
Lab2/main.ipynb
1099
Lab2/main.ipynb
File diff suppressed because one or more lines are too long
508
Lab2/pix2pix.py
Normal file
508
Lab2/pix2pix.py
Normal file
@ -0,0 +1,508 @@
|
|||||||
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Pix2pix.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from absl import app
|
||||||
|
from absl import flags
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
FLAGS = flags.FLAGS
|
||||||
|
|
||||||
|
flags.DEFINE_integer('buffer_size', 400, 'Shuffle buffer size')
|
||||||
|
flags.DEFINE_integer('batch_size', 1, 'Batch Size')
|
||||||
|
flags.DEFINE_integer('epochs', 1, 'Number of epochs')
|
||||||
|
flags.DEFINE_string('path', None, 'Path to the data folder')
|
||||||
|
flags.DEFINE_boolean('enable_function', True, 'Enable Function?')
|
||||||
|
|
||||||
|
IMG_WIDTH = 256
|
||||||
|
IMG_HEIGHT = 256
|
||||||
|
AUTOTUNE = tf.data.experimental.AUTOTUNE
|
||||||
|
|
||||||
|
|
||||||
|
def load(image_file):
|
||||||
|
"""Loads the image and generates input and target image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_file: .jpeg file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Input image, target image
|
||||||
|
"""
|
||||||
|
image = tf.io.read_file(image_file)
|
||||||
|
image = tf.image.decode_jpeg(image)
|
||||||
|
|
||||||
|
w = tf.shape(image)[1]
|
||||||
|
|
||||||
|
w = w // 2
|
||||||
|
real_image = image[:, :w, :]
|
||||||
|
input_image = image[:, w:, :]
|
||||||
|
|
||||||
|
input_image = tf.cast(input_image, tf.float32)
|
||||||
|
real_image = tf.cast(real_image, tf.float32)
|
||||||
|
|
||||||
|
return input_image, real_image
|
||||||
|
|
||||||
|
|
||||||
|
def resize(input_image, real_image, height, width):
|
||||||
|
input_image = tf.image.resize(input_image, [height, width],
|
||||||
|
method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
|
||||||
|
real_image = tf.image.resize(real_image, [height, width],
|
||||||
|
method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
|
||||||
|
|
||||||
|
return input_image, real_image
|
||||||
|
|
||||||
|
|
||||||
|
def random_crop(input_image, real_image):
|
||||||
|
stacked_image = tf.stack([input_image, real_image], axis=0)
|
||||||
|
cropped_image = tf.image.random_crop(
|
||||||
|
stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])
|
||||||
|
|
||||||
|
return cropped_image[0], cropped_image[1]
|
||||||
|
|
||||||
|
|
||||||
|
def normalize(input_image, real_image):
|
||||||
|
input_image = (input_image / 127.5) - 1
|
||||||
|
real_image = (real_image / 127.5) - 1
|
||||||
|
|
||||||
|
return input_image, real_image
|
||||||
|
|
||||||
|
|
||||||
|
@tf.function
|
||||||
|
def random_jitter(input_image, real_image):
|
||||||
|
"""Random jittering.
|
||||||
|
|
||||||
|
Resizes to 286 x 286 and then randomly crops to IMG_HEIGHT x IMG_WIDTH.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_image: Input Image
|
||||||
|
real_image: Real Image
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Input Image, real image
|
||||||
|
"""
|
||||||
|
# resizing to 286 x 286 x 3
|
||||||
|
input_image, real_image = resize(input_image, real_image, 286, 286)
|
||||||
|
|
||||||
|
# randomly cropping to 256 x 256 x 3
|
||||||
|
input_image, real_image = random_crop(input_image, real_image)
|
||||||
|
|
||||||
|
if tf.random.uniform(()) > 0.5:
|
||||||
|
# random mirroring
|
||||||
|
input_image = tf.image.flip_left_right(input_image)
|
||||||
|
real_image = tf.image.flip_left_right(real_image)
|
||||||
|
|
||||||
|
return input_image, real_image
|
||||||
|
|
||||||
|
|
||||||
|
def load_image_train(image_file):
|
||||||
|
input_image, real_image = load(image_file)
|
||||||
|
input_image, real_image = random_jitter(input_image, real_image)
|
||||||
|
input_image, real_image = normalize(input_image, real_image)
|
||||||
|
|
||||||
|
return input_image, real_image
|
||||||
|
|
||||||
|
|
||||||
|
def load_image_test(image_file):
|
||||||
|
input_image, real_image = load(image_file)
|
||||||
|
input_image, real_image = resize(input_image, real_image,
|
||||||
|
IMG_HEIGHT, IMG_WIDTH)
|
||||||
|
input_image, real_image = normalize(input_image, real_image)
|
||||||
|
|
||||||
|
return input_image, real_image
|
||||||
|
|
||||||
|
|
||||||
|
def create_dataset(path_to_train_images, path_to_test_images, buffer_size,
|
||||||
|
batch_size):
|
||||||
|
"""Creates a tf.data Dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path_to_train_images: Path to train images folder.
|
||||||
|
path_to_test_images: Path to test images folder.
|
||||||
|
buffer_size: Shuffle buffer size.
|
||||||
|
batch_size: Batch size
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
train dataset, test dataset
|
||||||
|
"""
|
||||||
|
train_dataset = tf.data.Dataset.list_files(path_to_train_images)
|
||||||
|
train_dataset = train_dataset.shuffle(buffer_size)
|
||||||
|
train_dataset = train_dataset.map(
|
||||||
|
load_image_train, num_parallel_calls=AUTOTUNE)
|
||||||
|
train_dataset = train_dataset.batch(batch_size)
|
||||||
|
|
||||||
|
test_dataset = tf.data.Dataset.list_files(path_to_test_images)
|
||||||
|
test_dataset = test_dataset.map(
|
||||||
|
load_image_test, num_parallel_calls=AUTOTUNE)
|
||||||
|
test_dataset = test_dataset.batch(batch_size)
|
||||||
|
|
||||||
|
return train_dataset, test_dataset
|
||||||
|
|
||||||
|
|
||||||
|
class InstanceNormalization(tf.keras.layers.Layer):
|
||||||
|
"""Instance Normalization Layer (https://arxiv.org/abs/1607.08022)."""
|
||||||
|
|
||||||
|
def __init__(self, epsilon=1e-5):
|
||||||
|
super(InstanceNormalization, self).__init__()
|
||||||
|
self.epsilon = epsilon
|
||||||
|
|
||||||
|
def build(self, input_shape):
|
||||||
|
self.scale = self.add_weight(
|
||||||
|
name='scale',
|
||||||
|
shape=input_shape[-1:],
|
||||||
|
initializer=tf.random_normal_initializer(1., 0.02),
|
||||||
|
trainable=True)
|
||||||
|
|
||||||
|
self.offset = self.add_weight(
|
||||||
|
name='offset',
|
||||||
|
shape=input_shape[-1:],
|
||||||
|
initializer='zeros',
|
||||||
|
trainable=True)
|
||||||
|
|
||||||
|
def call(self, x):
|
||||||
|
mean, variance = tf.nn.moments(x, axes=[1, 2], keepdims=True)
|
||||||
|
inv = tf.math.rsqrt(variance + self.epsilon)
|
||||||
|
normalized = (x - mean) * inv
|
||||||
|
return self.scale * normalized + self.offset
|
||||||
|
|
||||||
|
|
||||||
|
def downsample(filters, size, norm_type='batchnorm', apply_norm=True):
|
||||||
|
"""Downsamples an input.
|
||||||
|
|
||||||
|
Conv2D => Batchnorm => LeakyRelu
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filters: number of filters
|
||||||
|
size: filter size
|
||||||
|
norm_type: Normalization type; either 'batchnorm' or 'instancenorm'.
|
||||||
|
apply_norm: If True, adds the batchnorm layer
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Downsample Sequential Model
|
||||||
|
"""
|
||||||
|
initializer = tf.random_normal_initializer(0., 0.02)
|
||||||
|
|
||||||
|
result = tf.keras.Sequential()
|
||||||
|
result.add(
|
||||||
|
tf.keras.layers.Conv2D(filters, size, strides=2, padding='same',
|
||||||
|
kernel_initializer=initializer, use_bias=False))
|
||||||
|
|
||||||
|
if apply_norm:
|
||||||
|
if norm_type.lower() == 'batchnorm':
|
||||||
|
result.add(tf.keras.layers.BatchNormalization())
|
||||||
|
elif norm_type.lower() == 'instancenorm':
|
||||||
|
result.add(InstanceNormalization())
|
||||||
|
|
||||||
|
result.add(tf.keras.layers.LeakyReLU())
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def upsample(filters, size, norm_type='batchnorm', apply_dropout=False):
|
||||||
|
"""Upsamples an input.
|
||||||
|
|
||||||
|
Conv2DTranspose => Batchnorm => Dropout => Relu
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filters: number of filters
|
||||||
|
size: filter size
|
||||||
|
norm_type: Normalization type; either 'batchnorm' or 'instancenorm'.
|
||||||
|
apply_dropout: If True, adds the dropout layer
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Upsample Sequential Model
|
||||||
|
"""
|
||||||
|
|
||||||
|
initializer = tf.random_normal_initializer(0., 0.02)
|
||||||
|
|
||||||
|
result = tf.keras.Sequential()
|
||||||
|
result.add(
|
||||||
|
tf.keras.layers.Conv2DTranspose(filters, size, strides=2,
|
||||||
|
padding='same',
|
||||||
|
kernel_initializer=initializer,
|
||||||
|
use_bias=False))
|
||||||
|
|
||||||
|
if norm_type.lower() == 'batchnorm':
|
||||||
|
result.add(tf.keras.layers.BatchNormalization())
|
||||||
|
elif norm_type.lower() == 'instancenorm':
|
||||||
|
result.add(InstanceNormalization())
|
||||||
|
|
||||||
|
if apply_dropout:
|
||||||
|
result.add(tf.keras.layers.Dropout(0.5))
|
||||||
|
|
||||||
|
result.add(tf.keras.layers.ReLU())
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def unet_generator(output_channels, norm_type='batchnorm'):
|
||||||
|
"""Modified u-net generator model (https://arxiv.org/abs/1611.07004).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output_channels: Output channels
|
||||||
|
norm_type: Type of normalization. Either 'batchnorm' or 'instancenorm'.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Generator model
|
||||||
|
"""
|
||||||
|
|
||||||
|
down_stack = [
|
||||||
|
downsample(64, 4, norm_type, apply_norm=False), # (bs, 128, 128, 64)
|
||||||
|
downsample(128, 4, norm_type), # (bs, 64, 64, 128)
|
||||||
|
downsample(256, 4, norm_type), # (bs, 32, 32, 256)
|
||||||
|
downsample(512, 4, norm_type), # (bs, 16, 16, 512)
|
||||||
|
downsample(512, 4, norm_type), # (bs, 8, 8, 512)
|
||||||
|
downsample(512, 4, norm_type), # (bs, 4, 4, 512)
|
||||||
|
downsample(512, 4, norm_type), # (bs, 2, 2, 512)
|
||||||
|
downsample(512, 4, norm_type), # (bs, 1, 1, 512)
|
||||||
|
]
|
||||||
|
|
||||||
|
up_stack = [
|
||||||
|
upsample(512, 4, norm_type, apply_dropout=True), # (bs, 2, 2, 1024)
|
||||||
|
upsample(512, 4, norm_type, apply_dropout=True), # (bs, 4, 4, 1024)
|
||||||
|
upsample(512, 4, norm_type, apply_dropout=True), # (bs, 8, 8, 1024)
|
||||||
|
upsample(512, 4, norm_type), # (bs, 16, 16, 1024)
|
||||||
|
upsample(256, 4, norm_type), # (bs, 32, 32, 512)
|
||||||
|
upsample(128, 4, norm_type), # (bs, 64, 64, 256)
|
||||||
|
upsample(64, 4, norm_type), # (bs, 128, 128, 128)
|
||||||
|
]
|
||||||
|
|
||||||
|
initializer = tf.random_normal_initializer(0., 0.02)
|
||||||
|
last = tf.keras.layers.Conv2DTranspose(
|
||||||
|
output_channels, 4, strides=2,
|
||||||
|
padding='same', kernel_initializer=initializer,
|
||||||
|
activation='tanh') # (bs, 256, 256, 3)
|
||||||
|
|
||||||
|
concat = tf.keras.layers.Concatenate()
|
||||||
|
|
||||||
|
inputs = tf.keras.layers.Input(shape=[None, None, 3])
|
||||||
|
x = inputs
|
||||||
|
|
||||||
|
# Downsampling through the model
|
||||||
|
skips = []
|
||||||
|
for down in down_stack:
|
||||||
|
x = down(x)
|
||||||
|
skips.append(x)
|
||||||
|
|
||||||
|
skips = reversed(skips[:-1])
|
||||||
|
|
||||||
|
# Upsampling and establishing the skip connections
|
||||||
|
for up, skip in zip(up_stack, skips):
|
||||||
|
x = up(x)
|
||||||
|
x = concat([x, skip])
|
||||||
|
|
||||||
|
x = last(x)
|
||||||
|
|
||||||
|
return tf.keras.Model(inputs=inputs, outputs=x)
|
||||||
|
|
||||||
|
|
||||||
|
def discriminator(norm_type='batchnorm', target=True):
|
||||||
|
"""PatchGan discriminator model (https://arxiv.org/abs/1611.07004).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
norm_type: Type of normalization. Either 'batchnorm' or 'instancenorm'.
|
||||||
|
target: Bool, indicating whether target image is an input or not.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Discriminator model
|
||||||
|
"""
|
||||||
|
|
||||||
|
initializer = tf.random_normal_initializer(0., 0.02)
|
||||||
|
|
||||||
|
inp = tf.keras.layers.Input(shape=[None, None, 3], name='input_image')
|
||||||
|
x = inp
|
||||||
|
|
||||||
|
if target:
|
||||||
|
tar = tf.keras.layers.Input(shape=[None, None, 3], name='target_image')
|
||||||
|
x = tf.keras.layers.concatenate([inp, tar]) # (bs, 256, 256, channels*2)
|
||||||
|
|
||||||
|
down1 = downsample(64, 4, norm_type, False)(x) # (bs, 128, 128, 64)
|
||||||
|
down2 = downsample(128, 4, norm_type)(down1) # (bs, 64, 64, 128)
|
||||||
|
down3 = downsample(256, 4, norm_type)(down2) # (bs, 32, 32, 256)
|
||||||
|
|
||||||
|
zero_pad1 = tf.keras.layers.ZeroPadding2D()(down3) # (bs, 34, 34, 256)
|
||||||
|
conv = tf.keras.layers.Conv2D(
|
||||||
|
512, 4, strides=1, kernel_initializer=initializer,
|
||||||
|
use_bias=False)(zero_pad1) # (bs, 31, 31, 512)
|
||||||
|
|
||||||
|
if norm_type.lower() == 'batchnorm':
|
||||||
|
norm1 = tf.keras.layers.BatchNormalization()(conv)
|
||||||
|
elif norm_type.lower() == 'instancenorm':
|
||||||
|
norm1 = InstanceNormalization()(conv)
|
||||||
|
|
||||||
|
leaky_relu = tf.keras.layers.LeakyReLU()(norm1)
|
||||||
|
|
||||||
|
zero_pad2 = tf.keras.layers.ZeroPadding2D()(leaky_relu) # (bs, 33, 33, 512)
|
||||||
|
|
||||||
|
last = tf.keras.layers.Conv2D(
|
||||||
|
1, 4, strides=1,
|
||||||
|
kernel_initializer=initializer)(zero_pad2) # (bs, 30, 30, 1)
|
||||||
|
|
||||||
|
if target:
|
||||||
|
return tf.keras.Model(inputs=[inp, tar], outputs=last)
|
||||||
|
else:
|
||||||
|
return tf.keras.Model(inputs=inp, outputs=last)
|
||||||
|
|
||||||
|
|
||||||
|
def get_checkpoint_prefix():
|
||||||
|
checkpoint_dir = './training_checkpoints'
|
||||||
|
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
|
||||||
|
|
||||||
|
return checkpoint_prefix
|
||||||
|
|
||||||
|
|
||||||
|
class Pix2pix(object):
|
||||||
|
"""Pix2pix class.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
epochs: Number of epochs.
|
||||||
|
enable_function: If true, train step is decorated with tf.function.
|
||||||
|
buffer_size: Shuffle buffer size..
|
||||||
|
batch_size: Batch size.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, epochs, enable_function):
|
||||||
|
self.epochs = epochs
|
||||||
|
self.enable_function = enable_function
|
||||||
|
self.lambda_value = 100
|
||||||
|
self.loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
|
||||||
|
self.generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
|
||||||
|
self.discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
|
||||||
|
self.generator = unet_generator(output_channels=3)
|
||||||
|
self.discriminator = discriminator()
|
||||||
|
self.checkpoint = tf.train.Checkpoint(
|
||||||
|
generator_optimizer=self.generator_optimizer,
|
||||||
|
discriminator_optimizer=self.discriminator_optimizer,
|
||||||
|
generator=self.generator,
|
||||||
|
discriminator=self.discriminator)
|
||||||
|
|
||||||
|
def discriminator_loss(self, disc_real_output, disc_generated_output):
|
||||||
|
real_loss = self.loss_object(
|
||||||
|
tf.ones_like(disc_real_output), disc_real_output)
|
||||||
|
|
||||||
|
generated_loss = self.loss_object(tf.zeros_like(
|
||||||
|
disc_generated_output), disc_generated_output)
|
||||||
|
|
||||||
|
total_disc_loss = real_loss + generated_loss
|
||||||
|
|
||||||
|
return total_disc_loss
|
||||||
|
|
||||||
|
def generator_loss(self, disc_generated_output, gen_output, target):
|
||||||
|
gan_loss = self.loss_object(tf.ones_like(
|
||||||
|
disc_generated_output), disc_generated_output)
|
||||||
|
|
||||||
|
# mean absolute error
|
||||||
|
l1_loss = tf.reduce_mean(tf.abs(target - gen_output))
|
||||||
|
total_gen_loss = gan_loss + (self.lambda_value * l1_loss)
|
||||||
|
return total_gen_loss
|
||||||
|
|
||||||
|
def train_step(self, input_image, target_image):
|
||||||
|
"""One train step over the generator and discriminator model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_image: Input Image.
|
||||||
|
target_image: Target image.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
generator loss, discriminator loss.
|
||||||
|
"""
|
||||||
|
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
|
||||||
|
gen_output = self.generator(input_image, training=True)
|
||||||
|
|
||||||
|
disc_real_output = self.discriminator(
|
||||||
|
[input_image, target_image], training=True)
|
||||||
|
disc_generated_output = self.discriminator(
|
||||||
|
[input_image, gen_output], training=True)
|
||||||
|
|
||||||
|
gen_loss = self.generator_loss(
|
||||||
|
disc_generated_output, gen_output, target_image)
|
||||||
|
disc_loss = self.discriminator_loss(
|
||||||
|
disc_real_output, disc_generated_output)
|
||||||
|
|
||||||
|
generator_gradients = gen_tape.gradient(
|
||||||
|
gen_loss, self.generator.trainable_variables)
|
||||||
|
discriminator_gradients = disc_tape.gradient(
|
||||||
|
disc_loss, self.discriminator.trainable_variables)
|
||||||
|
|
||||||
|
self.generator_optimizer.apply_gradients(zip(
|
||||||
|
generator_gradients, self.generator.trainable_variables))
|
||||||
|
self.discriminator_optimizer.apply_gradients(zip(
|
||||||
|
discriminator_gradients, self.discriminator.trainable_variables))
|
||||||
|
|
||||||
|
return gen_loss, disc_loss
|
||||||
|
|
||||||
|
def train(self, dataset, checkpoint_pr):
|
||||||
|
"""Train the GAN for x number of epochs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: train dataset.
|
||||||
|
checkpoint_pr: prefix in which the checkpoints are stored.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Time for each epoch.
|
||||||
|
"""
|
||||||
|
time_list = []
|
||||||
|
if self.enable_function:
|
||||||
|
self.train_step = tf.function(self.train_step)
|
||||||
|
|
||||||
|
for epoch in range(self.epochs):
|
||||||
|
start_time = time.time()
|
||||||
|
for input_image, target_image in dataset:
|
||||||
|
gen_loss, disc_loss = self.train_step(input_image, target_image)
|
||||||
|
|
||||||
|
wall_time_sec = time.time() - start_time
|
||||||
|
time_list.append(wall_time_sec)
|
||||||
|
|
||||||
|
# saving (checkpoint) the model every 20 epochs
|
||||||
|
if (epoch + 1) % 20 == 0:
|
||||||
|
self.checkpoint.save(file_prefix=checkpoint_pr)
|
||||||
|
|
||||||
|
template = 'Epoch {}, Generator loss {}, Discriminator Loss {}'
|
||||||
|
print (template.format(epoch, gen_loss, disc_loss))
|
||||||
|
|
||||||
|
return time_list
|
||||||
|
|
||||||
|
|
||||||
|
def run_main(argv):
|
||||||
|
del argv
|
||||||
|
kwargs = {'epochs': FLAGS.epochs, 'enable_function': FLAGS.enable_function,
|
||||||
|
'path': FLAGS.path, 'buffer_size': FLAGS.buffer_size,
|
||||||
|
'batch_size': FLAGS.batch_size}
|
||||||
|
main(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def main(epochs, enable_function, path, buffer_size, batch_size):
|
||||||
|
path_to_folder = path
|
||||||
|
|
||||||
|
pix2pix_object = Pix2pix(epochs, enable_function)
|
||||||
|
|
||||||
|
train_dataset, _ = create_dataset(
|
||||||
|
os.path.join(path_to_folder, 'train/*.jpg'),
|
||||||
|
os.path.join(path_to_folder, 'test/*.jpg'),
|
||||||
|
buffer_size, batch_size)
|
||||||
|
checkpoint_pr = get_checkpoint_prefix()
|
||||||
|
print ('Training ...')
|
||||||
|
return pix2pix_object.train(train_dataset, checkpoint_pr)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
app.run(run_main)
|
BIN
requirements.txt
BIN
requirements.txt
Binary file not shown.
Loading…
Reference in New Issue
Block a user