linzhong / BigGAN

ops.py
linzhong committed on 2021-02-22 15:24 (add source code)

import tensorflow as tf
import tensorflow.contrib as tf_contrib
# Xavier : tf_contrib.layers.xavier_initializer()
# He : tf_contrib.layers.variance_scaling_initializer()
# Normal : tf.random_normal_initializer(mean=0.0, stddev=0.02)
# l2_decay : tf_contrib.layers.l2_regularizer(0.0001)
weight_init = tf_contrib.layers.xavier_initializer()
weight_regularizer = None
weight_regularizer_fully = None
##################################################################################
# Layer
##################################################################################
def conv(x, channels, kernel=4, stride=2, pad=0, pad_type='zero', use_bias=True, sn=False, scope='conv_0'):
    with tf.variable_scope(scope):
        if pad > 0:
            h = x.get_shape().as_list()[1]
            if h % stride == 0:
                pad = pad * 2
            else:
                pad = max(kernel - (h % stride), 0)

            pad_top = pad // 2
            pad_bottom = pad - pad_top
            pad_left = pad // 2
            pad_right = pad - pad_left

            if pad_type == 'zero':
                x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]])
            if pad_type == 'reflect':
                x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]], mode='REFLECT')

        if sn:
            w = tf.get_variable("kernel", shape=[kernel, kernel, x.get_shape()[-1], channels],
                                initializer=weight_init, regularizer=weight_regularizer)
            x = tf.nn.conv2d(input=x, filter=spectral_norm(w),
                             strides=[1, stride, stride, 1], padding='VALID')
            if use_bias:
                bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0))
                x = tf.nn.bias_add(x, bias)
        else:
            x = tf.layers.conv2d(inputs=x, filters=channels,
                                 kernel_size=kernel, kernel_initializer=weight_init,
                                 kernel_regularizer=weight_regularizer,
                                 strides=stride, use_bias=use_bias)

        return x

def deconv(x, channels, kernel=4, stride=2, padding='SAME', use_bias=True, sn=False, scope='deconv_0'):
    with tf.variable_scope(scope):
        x_shape = x.get_shape().as_list()

        if padding == 'SAME':
            output_shape = [x_shape[0], x_shape[1] * stride, x_shape[2] * stride, channels]
        else:
            output_shape = [x_shape[0], x_shape[1] * stride + max(kernel - stride, 0),
                            x_shape[2] * stride + max(kernel - stride, 0), channels]

        if sn:
            w = tf.get_variable("kernel", shape=[kernel, kernel, channels, x.get_shape()[-1]],
                                initializer=weight_init, regularizer=weight_regularizer)
            x = tf.nn.conv2d_transpose(x, filter=spectral_norm(w), output_shape=output_shape,
                                       strides=[1, stride, stride, 1], padding=padding)

            if use_bias:
                bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0))
                x = tf.nn.bias_add(x, bias)
        else:
            x = tf.layers.conv2d_transpose(inputs=x, filters=channels,
                                           kernel_size=kernel, kernel_initializer=weight_init,
                                           kernel_regularizer=weight_regularizer,
                                           strides=stride, padding=padding, use_bias=use_bias)

        return x

def fully_connected(x, units, use_bias=True, sn=False, scope='linear'):
    with tf.variable_scope(scope):
        x = flatten(x)
        shape = x.get_shape().as_list()
        channels = shape[-1]

        if sn:
            w = tf.get_variable("kernel", [channels, units], tf.float32,
                                initializer=weight_init, regularizer=weight_regularizer_fully)
            if use_bias:
                bias = tf.get_variable("bias", [units], initializer=tf.constant_initializer(0.0))
                x = tf.matmul(x, spectral_norm(w)) + bias
            else:
                x = tf.matmul(x, spectral_norm(w))
        else:
            x = tf.layers.dense(x, units=units, kernel_initializer=weight_init,
                                kernel_regularizer=weight_regularizer_fully,
                                use_bias=use_bias)

        return x

def flatten(x):
    return tf.layers.flatten(x)

def hw_flatten(x):
    # collapse the spatial dimensions: (B, H, W, C) -> (B, H*W, C)
    return tf.reshape(x, shape=[x.shape[0], -1, x.shape[-1]])

##################################################################################
# Residual-block
##################################################################################
def up_resblock(x_init, channels, use_bias=True, is_training=True, sn=False, scope='resblock'):
    with tf.variable_scope(scope):
        with tf.variable_scope('res1'):
            x = batch_norm(x_init, is_training)
            x = relu(x)
            x = up_sample(x, scale_factor=2)
            x = conv(x, channels, kernel=3, stride=1, pad=1, pad_type='zero', use_bias=False, sn=sn)

        with tf.variable_scope('res2'):
            x = batch_norm(x, is_training)
            x = relu(x)
            x = conv(x, channels, kernel=3, stride=1, pad=1, pad_type='zero', use_bias=use_bias, sn=sn)

        with tf.variable_scope('shortcut'):
            x_init = up_sample(x_init, scale_factor=2)
            x_init = conv(x_init, channels, kernel=1, stride=1, use_bias=False, sn=sn)

        return x + x_init

def down_resblock(x_init, channels, to_down=True, use_bias=True, sn=False, scope='resblock'):
    with tf.variable_scope(scope):
        init_channel = x_init.shape.as_list()[-1]

        with tf.variable_scope('res1'):
            x = lrelu(x_init, 0.2)
            x = conv(x, channels, kernel=3, stride=1, pad=1, pad_type='zero', use_bias=use_bias, sn=sn)

        with tf.variable_scope('res2'):
            x = lrelu(x, 0.2)
            x = conv(x, channels, kernel=3, stride=1, pad=1, pad_type='zero', use_bias=use_bias, sn=sn)

            if to_down:
                x = down_sample(x)

        if to_down or init_channel != channels:
            with tf.variable_scope('shortcut'):
                x_init = conv(x_init, channels, kernel=1, stride=1, use_bias=use_bias, sn=sn)

                if to_down:
                    x_init = down_sample(x_init)

        return x + x_init

def init_down_resblock(x_init, channels, use_bias=True, sn=False, scope='resblock'):
    with tf.variable_scope(scope):
        with tf.variable_scope('res1'):
            x = conv(x_init, channels, kernel=3, stride=1, pad=1, pad_type='zero', use_bias=use_bias, sn=sn)
            x = lrelu(x, 0.2)

        with tf.variable_scope('res2'):
            x = conv(x, channels, kernel=3, stride=1, pad=1, pad_type='zero', use_bias=use_bias, sn=sn)
            x = down_sample(x)

        with tf.variable_scope('shortcut'):
            x_init = down_sample(x_init)
            x_init = conv(x_init, channels, kernel=1, stride=1, use_bias=use_bias, sn=sn)

        return x + x_init

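# ---------------------------------------------------------------------------
# Editor's illustrative sketch (not part of the original file): one plausible
# way to chain the residual blocks above into a tiny generator/discriminator
# pair. The channel widths, depths, and scope names are assumptions made for
# this example only, not the author's actual network definition. The helpers
# used here (up_sample, down_sample, global_sum_pooling, activations, ...)
# are defined further down in this module and resolved at call time.
# ---------------------------------------------------------------------------
def _toy_generator(z, ch=64, is_training=True, sn=True):
    # z: [batch, z_dim] latent vector -> [batch, 32, 32, 3] image in [-1, 1]
    with tf.variable_scope('toy_generator', reuse=tf.AUTO_REUSE):
        x = fully_connected(z, units=4 * 4 * ch * 4, sn=sn, scope='fc')
        x = tf.reshape(x, [-1, 4, 4, ch * 4])
        x = up_resblock(x, channels=ch * 2, is_training=is_training, sn=sn, scope='up_resblock_0')  # 8x8
        x = up_resblock(x, channels=ch, is_training=is_training, sn=sn, scope='up_resblock_1')      # 16x16
        x = up_resblock(x, channels=ch, is_training=is_training, sn=sn, scope='up_resblock_2')      # 32x32
        x = relu(batch_norm(x, is_training))
        x = conv(x, channels=3, kernel=3, stride=1, pad=1, sn=sn, scope='to_rgb')
        return tanh(x)

def _toy_discriminator(x, ch=64, sn=True):
    # x: [batch, 32, 32, 3] image -> [batch, 1] logit
    with tf.variable_scope('toy_discriminator', reuse=tf.AUTO_REUSE):
        x = init_down_resblock(x, channels=ch, sn=sn, scope='init_resblock')                  # 16x16
        x = down_resblock(x, channels=ch * 2, sn=sn, scope='down_resblock_0')                 # 8x8
        x = down_resblock(x, channels=ch * 2, to_down=False, sn=sn, scope='down_resblock_1')  # 8x8
        x = lrelu(x, 0.2)
        x = global_sum_pooling(x)
        return fully_connected(x, units=1, sn=sn, scope='logit')
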
##################################################################################
# Sampling
##################################################################################
def global_avg_pooling(x):
    gap = tf.reduce_mean(x, axis=[1, 2])
    return gap

def global_sum_pooling(x):
    gsp = tf.reduce_sum(x, axis=[1, 2])
    return gsp

def up_sample(x, scale_factor=2):
    _, h, w, _ = x.get_shape().as_list()
    new_size = [h * scale_factor, w * scale_factor]
    return tf.image.resize_nearest_neighbor(x, size=new_size)

def down_sample(x):
    return tf.layers.average_pooling2d(x, pool_size=2, strides=2, padding='SAME')

def max_pooling(x):
    return tf.layers.max_pooling2d(x, pool_size=2, strides=2, padding='SAME')

##################################################################################
# Activation function
##################################################################################
def lrelu(x, alpha=0.2):
    return tf.nn.leaky_relu(x, alpha)

def relu(x):
    return tf.nn.relu(x)

def tanh(x):
    return tf.tanh(x)

##################################################################################
# Normalization function
##################################################################################
def batch_norm(x, is_training=True, scope='batch_norm'):
    return tf_contrib.layers.batch_norm(x,
                                        decay=0.9, epsilon=1e-05,
                                        center=True, scale=True, updates_collections=None,
                                        is_training=is_training, scope=scope)

def spectral_norm(w, iteration=1):
    w_shape = w.shape.as_list()
    w = tf.reshape(w, [-1, w_shape[-1]])

    u = tf.get_variable("u", [1, w_shape[-1]], initializer=tf.random_normal_initializer(), trainable=False)

    u_hat = u
    v_hat = None
    for i in range(iteration):
        # power iteration; usually iteration = 1 is enough
        v_ = tf.matmul(u_hat, tf.transpose(w))
        v_hat = tf.nn.l2_normalize(v_)

        u_ = tf.matmul(v_hat, w)
        u_hat = tf.nn.l2_normalize(u_)

    u_hat = tf.stop_gradient(u_hat)
    v_hat = tf.stop_gradient(v_hat)

    sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat))

    with tf.control_dependencies([u.assign(u_hat)]):
        w_norm = w / sigma
        w_norm = tf.reshape(w_norm, w_shape)

    return w_norm

##################################################################################
# Loss function
##################################################################################
def discriminator_loss(loss_func, real, fake):
    real_loss = 0
    fake_loss = 0

    if 'wgan' in loss_func:
        real_loss = -tf.reduce_mean(real)
        fake_loss = tf.reduce_mean(fake)

    if loss_func == 'lsgan':
        real_loss = tf.reduce_mean(tf.squared_difference(real, 1.0))
        fake_loss = tf.reduce_mean(tf.square(fake))

    if loss_func == 'gan' or loss_func == 'dragan':
        real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(real), logits=real))
        fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(fake), logits=fake))

    if loss_func == 'hinge':
        real_loss = tf.reduce_mean(relu(1.0 - real))
        fake_loss = tf.reduce_mean(relu(1.0 + fake))

    loss = real_loss + fake_loss

    return loss

def generator_loss(loss_func, fake):
    fake_loss = 0

    if 'wgan' in loss_func:
        fake_loss = -tf.reduce_mean(fake)

    if loss_func == 'lsgan':
        fake_loss = tf.reduce_mean(tf.squared_difference(fake, 1.0))

    if loss_func == 'gan' or loss_func == 'dragan':
        fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(fake), logits=fake))

    if loss_func == 'hinge':
        fake_loss = -tf.reduce_mean(fake)

    loss = fake_loss

    return loss

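# ---------------------------------------------------------------------------
# Editor's minimal smoke test (not part of the original repository): builds a
# single spectrally-normalized conv "discriminator" head, evaluates the hinge
# losses on random data, and prints the results. The shapes, scope names, and
# the helper _toy_logit are assumptions made for this example only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import numpy as np

    def _toy_logit(images, reuse=False):
        # one sn-conv plus a linear head, just enough to exercise the ops above
        with tf.variable_scope('toy_d', reuse=reuse):
            h = lrelu(conv(images, channels=32, kernel=4, stride=2, pad=1, sn=True, scope='conv'), 0.2)
            h = global_sum_pooling(h)
            return fully_connected(h, units=1, sn=True, scope='logit')

    real = tf.placeholder(tf.float32, [4, 32, 32, 3], name='real')
    fake = tf.placeholder(tf.float32, [4, 32, 32, 3], name='fake')

    d_loss = discriminator_loss('hinge', real=_toy_logit(real), fake=_toy_logit(fake, reuse=True))
    g_loss = generator_loss('hinge', fake=_toy_logit(fake, reuse=True))

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        d_val, g_val = sess.run(
            [d_loss, g_loss],
            feed_dict={real: np.random.rand(4, 32, 32, 3).astype(np.float32),
                       fake: np.random.rand(4, 32, 32, 3).astype(np.float32)})
        print('hinge d_loss = %.4f, g_loss = %.4f' % (d_val, g_val))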
Repository: https://gitee.com/echo_lin/big-gan.git (branch: master)