当前位置: 首页> 科技> 名企 > 【深度学习】CycleGAN

【深度学习】CycleGAN

时间:2025/7/14 12:41:16来源:https://blog.csdn.net/a13545564067/article/details/140292087 浏览次数:0次

CycleGAN(Cycle-Consistent Generative Adversarial Network)是一种生成对抗网络(GAN)架构,用于图像到图像的翻译任务,无需成对的训练样本。CycleGAN 可以在两个域之间进行图像转换,例如将马转换为斑马,将白天的风景转换为夜晚的风景等。

CycleGAN 的基本架构

CycleGAN 包含两个生成器和两个判别器:

  • 生成器 G:将图像从域 X 转换到域 Y。
  • 生成器 F:将图像从域 Y 转换到域 X。
  • 判别器 D_X:区分图像是否来自域 X。
  • 判别器 D_Y:区分图像是否来自域 Y。

为了确保转换的图像保留原图像的特征,CycleGAN 使用循环一致性损失(Cycle-Consistency Loss)。即,图像经过两个生成器的循环转换后应尽可能恢复到原图像。

损失函数

CycleGAN 的损失函数包括三部分:

  1. 对抗损失(Adversarial Loss):用于确保生成器生成的图像看起来像目标域中的图像。
  2. 循环一致性损失(Cycle-Consistency Loss):确保图像经过两个生成器的转换后能恢复到原图像。
  3. 身份损失(Identity Loss):确保生成器在生成图像时保留输入图像的特征。

TensorFlow 实现示例

以下是一个使用 TensorFlow 和 Keras 实现 CycleGAN 的简化示例。这个示例展示了如何定义生成器和判别器,以及训练 CycleGAN。

import tensorflow as tf
from tensorflow.keras import layers# 定义生成器模型
def build_generator():inputs = tf.keras.Input(shape=[256, 256, 3])x = layers.Conv2D(64, (7, 7), padding='same')(inputs)x = layers.BatchNormalization()(x)x = layers.ReLU()(x)# 多层卷积和反卷积层(简化版)x = layers.Conv2D(128, (3, 3), strides=2, padding='same')(x)x = layers.BatchNormalization()(x)x = layers.ReLU()(x)x = layers.Conv2DTranspose(64, (3, 3), strides=2, padding='same')(x)x = layers.BatchNormalization()(x)x = layers.ReLU()(x)outputs = layers.Conv2D(3, (7, 7), padding='same', activation='tanh')(x)return tf.keras.Model(inputs, outputs)# 定义判别器模型
def build_discriminator():inputs = tf.keras.Input(shape=[256, 256, 3])x = layers.Conv2D(64, (4, 4), strides=2, padding='same')(inputs)x = layers.LeakyReLU(alpha=0.2)(x)x = layers.Conv2D(128, (4, 4), strides=2, padding='same')(x)x = layers.BatchNormalization()(x)x = layers.LeakyReLU(alpha=0.2)(x)x = layers.Conv2D(256, (4, 4), strides=2, padding='same')(x)x = layers.BatchNormalization()(x)x = layers.LeakyReLU(alpha=0.2)(x)outputs = layers.Conv2D(1, (4, 4), padding='same')(x)return tf.keras.Model(inputs, outputs)# 创建生成器和判别器
G = build_generator()
F = build_generator()
D_X = build_discriminator()
D_Y = build_discriminator()# 定义损失函数
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)# 定义优化器
generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)# 对抗损失
def discriminator_loss(real, generated):real_loss = loss_obj(tf.ones_like(real), real)generated_loss = loss_obj(tf.zeros_like(generated), generated)total_loss = real_loss + generated_lossreturn total_loss * 0.5def generator_loss(generated):return loss_obj(tf.ones_like(generated), generated)# 循环一致性损失
def cycle_consistency_loss(real, cycled):return tf.reduce_mean(tf.abs(real - cycled))# 身份损失
def identity_loss(real, same):return tf.reduce_mean(tf.abs(real - same))# 训练步骤
@tf.function
def train_step(real_x, real_y):with tf.GradientTape(persistent=True) as tape:# 生成图像fake_y = G(real_x, training=True)cycled_x = F(fake_y, training=True)fake_x = F(real_y, training=True)cycled_y = G(fake_x, training=True)# 生成的图像与真实图像的相似性same_x = F(real_x, training=True)same_y = G(real_y, training=True)# 判别器判断真假disc_real_x = D_X(real_x, training=True)disc_real_y = D_Y(real_y, training=True)disc_fake_x = D_X(fake_x, training=True)disc_fake_y = D_Y(fake_y, training=True)# 计算损失gen_g_loss = generator_loss(disc_fake_y)gen_f_loss = generator_loss(disc_fake_x)total_cycle_loss = cycle_consistency_loss(real_x, cycled_x) + cycle_consistency_loss(real_y, cycled_y)total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y) * 0.5total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x) * 0.5disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)# 计算梯度并应用优化器generator_gradients_g = tape.gradient(total_gen_g_loss, G.trainable_variables)generator_gradients_f = tape.gradient(total_gen_f_loss, F.trainable_variables)discriminator_gradients_x = tape.gradient(disc_x_loss, D_X.trainable_variables)discriminator_gradients_y = tape.gradient(disc_y_loss, D_Y.trainable_variables)generator_optimizer.apply_gradients(zip(generator_gradients_g, G.trainable_variables))generator_optimizer.apply_gradients(zip(generator_gradients_f, F.trainable_variables))discriminator_optimizer.apply_gradients(zip(discriminator_gradients_x, D_X.trainable_variables))discriminator_optimizer.apply_gradients(zip(discriminator_gradients_y, D_Y.trainable_variables))# 训练循环
def train(dataset, epochs):for epoch in range(epochs):for real_x, real_y in dataset:train_step(real_x, real_y)# 示例数据集(这里需要你自己的数据)
# dataset = tf.data.Dataset.from_tensor_slices((real_x_images, real_y_images)).batch(1)# 训练模型
# train(dataset, epochs=100)
解释
  1. 生成器和判别器

    • 使用卷积和反卷积层(转置卷积)定义生成器模型。
    • 使用卷积层定义判别器模型。
  2. 损失函数

    • 对抗损失用于生成器和判别器。
    • 循环一致性损失确保图像能在转换后恢复。
    • 身份损失确保生成器保留输入图像的特征。
  3. 优化器

    • 使用 Adam 优化器,学习率为 2e-4beta_1 设置为 0.5。
  4. 训练步骤

    • 定义训练步骤函数 train_step,包括前向传播、计算损失和应用梯度。
    • @tf.function 装饰器用于加速训练步骤的执行。
  5. 训练循环

    • 定义训练循环函数 train,迭代数据集并调用 train_step

结论

CycleGAN 是一种强大的模型,可以在没有成对样本的情况下进行图像到图像的转换。通过定义生成器和判别器,以及使用对抗损失、循环一致性损失和身份损失,CycleGAN 能够学习在两个域之间进行有效的图像转换。这个示例提供了一个基本的实现框架,你可以根据具体任务和数据集进行调整和扩展。

关键字:【深度学习】CycleGAN

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

责任编辑: