当前位置: 首页> 游戏> 手游 > 微信建站网站_北京校园文化设计公司_引擎优化seo怎么做_网站交易网

微信建站网站_北京校园文化设计公司_引擎优化seo怎么做_网站交易网

时间:2025/7/11 8:46:38来源:https://blog.csdn.net/weixin_41369892/article/details/146925527 浏览次数:0次
微信建站网站_北京校园文化设计公司_引擎优化seo怎么做_网站交易网

详细代码及训练得到的8倍超分辨率模型已放在GitHub

Github: SuperResolution-DDIM-SwinUnet

简介

  • 在DIV2K数据集(800张2K图像)上训练了一个8倍超分辨率模型。条件输入方式与SR3相同:将(上采样到目标尺寸的)低分辨率图像与噪声图像在通道维拼接后输入模型。不过没有采用SR3直接输入噪声强度的做法,而是沿用输入去噪步数t作为时间条件的方法,并将DDPM的步数增加到1000(如果仅用100步,输出结果的噪点会比较多)。

  • 效果图放在了Github的result目录里。引入了DDIM采样(这也是使用步数t作为时间条件的好处):从结果看,DDIM仅需采样40步(即默认的 --ddim_sub_sequence_steps 25,1000/25=40)效果就与DDPM采样1000步相当;DDIM采样1步或2步也能大体还原,不过质量不高。

不足:

1.可能是使用SwinUnet的关系,超分辨率后的图像总是能隐约看到“小框框”;而且图像大小必须能被256整除(这个其实好解决,resize即可)。
2.只做了一个8倍超分辨率的模型(倍数太大,从效果来看失真较严重)。可以考虑训练倍率较低的模型(比如2倍和4倍),再级联使用来实现8倍的效果,失真可能会小一些。

代码:(run.py、scheduler.py、SwinUnet.py、load_data.py、training.py)
"run.py"
import argparse
import datetime
import os

import numpy as np
import torch
from PIL import Image

from SwinUnet import SwinUnet
from scheduler import Scheduler


def main(args):
    """Super-resolve one image with the SwinUnet diffusion model and save it as PNG.

    The input is snapped down to a multiple of 256 per side, then resized up by
    ``args.sr_ratio`` to the target resolution and used as the condition for
    DDIM (``args.use_ddim`` truthy) or DDPM sampling.

    Args:
        args: parsed CLI namespace (see the argparse setup at the bottom).

    Returns:
        Path of the written PNG inside ``args.results_dir``.

    Raises:
        ValueError: if either side of the input image is smaller than 256 px.
    """
    device = torch.device(args.device)
    model = SwinUnet(channels=3, dim=96, mlp_ratio=4, patch_size=4, window_size=8,
                     depth=[2, 2, 6, 2], nheads=[3, 6, 12, 24]).to(device)
    sr_ratio = args.sr_ratio
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints
    # (recent PyTorch versions accept weights_only=True to restrict this).
    model.load_state_dict(torch.load(args.model_path, map_location=device))
    model.eval()
    scheduler = Scheduler(model, args.denoise_steps)

    # convert("RGB") handles grayscale, palette and RGBA inputs uniformly; the
    # context manager releases the underlying file handle.
    with Image.open(args.image_path) as src:
        img = src.convert("RGB")
    width, height = img.size
    # Explicit check instead of assert (asserts are stripped under `python -O`).
    if width < 256 or height < 256:
        raise ValueError("图片的最小尺寸为256")
    # Both sides must be multiples of 256 for the network built above
    # (patch_size 4, three 2x patch merges, window_size 8 at the bottleneck).
    target_size = ((width // 256) * 256 * sr_ratio, (height // 256) * 256 * sr_ratio)
    img = img.resize(target_size)

    cond = np.array(img).transpose(2, 0, 1) / 255.
    cond = 2 * (cond - 0.5)  # [0, 1] -> [-1, 1]
    cond = torch.from_numpy(cond).float().to(device).unsqueeze(0)

    if args.use_ddim:
        y = scheduler.ddim(cond, device, sub_sequence_step=args.ddim_sub_sequence_steps)[-1]
    else:
        y = scheduler.ddpm(cond, device)[-1]

    # (C, H, W) in [-1, 1] -> (H, W, C) uint8, rounding rather than truncating.
    y = y.transpose(1, 2, 0)
    y = np.clip(np.rint((y + 1.) / 2 * 255.0), 0, 255)
    new_img = Image.fromarray(y.astype(np.uint8))

    os.makedirs(args.results_dir, exist_ok=True)
    # str(datetime.now()) contains ':' and a space, which Windows rejects in file names.
    stamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S-%f")
    out_path = os.path.join(args.results_dir, stamp + ".png")
    new_img.save(out_path)
    return out_path


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Diffusion-based image super-resolution (SwinUnet).")
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--image_path", type=str, required=True, help="input (low-resolution) image")
    parser.add_argument("--sr_ratio", type=int, default=8)
    parser.add_argument("--results_dir", type=str, default="./results")
    parser.add_argument("--denoise_steps", type=int, default=1000)
    parser.add_argument("--model_path", type=str, default="SwinUNet-SR8.pth")
    parser.add_argument("--use_ddim", type=int, default=1, help="1: DDIM sampling, 0: DDPM sampling")
    parser.add_argument("--ddim_sub_sequence_steps", type=int, default=25, help="DDIM timestep stride")
    args = parser.parse_args()
    main(args)
"scheduler.py"
import numpy as np
import torch
import torch.nn.functional as F

try:
    from tqdm import tqdm
except ImportError:  # tqdm only draws progress bars; sampling works without it
    def tqdm(iterable, **_kwargs):
        return iterable


def extract_into_tensor(arr, timesteps, broadcast_shape):
    """Gather ``arr[timesteps]`` and broadcast it to ``broadcast_shape``.

    Args:
        arr: 1-D numpy array of per-timestep coefficients.
        timesteps: integer tensor of shape ``(B,)``; its device is used for the result.
        broadcast_shape: target shape whose first dimension is ``B``.

    Returns:
        float32 tensor of ``broadcast_shape`` in which every element of sample
        ``b`` equals ``arr[timesteps[b]]``.
    """
    # Cast to float32 before moving to the device (some backends, e.g. MPS,
    # do not support float64).
    res = torch.from_numpy(arr).to(torch.float32).to(device=timesteps.device)[timesteps]
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res + torch.zeros(broadcast_shape, device=timesteps.device)


class Scheduler:
    """Linear-beta Gaussian diffusion for conditional (image-to-image) denoising.

    ``denoise_model(inp, t)`` receives the condition ``x`` and the noisy sample
    ``y`` concatenated along dim 1 and must return a noise prediction with the
    shape of ``y``.

    NOTE(review): with the default schedule (betas 1e-4 -> 0.005 over 1000 steps)
    ``alphas_cumprod[-1]`` is about 0.08, so the last forward step still keeps
    roughly 0.28 * y0, while both samplers start from pure N(0, I) noise. The
    defaults are left unchanged because the training script uses them, so
    existing checkpoints were trained with this schedule.
    """

    def __init__(self, denoise_model, denoise_steps, beta_start=1e-4, beta_end=0.005):
        self.model = denoise_model
        betas = np.array(np.linspace(beta_start, beta_end, denoise_steps), dtype=np.float64)
        self.denoise_steps = denoise_steps
        assert len(betas.shape) == 1, "betas must be 1-D"
        assert (betas > 0).all() and (betas <= 1).all()
        alphas = 1.0 - betas
        self.sqrt_alphas = np.sqrt(alphas)
        self.one_minus_alphas = 1.0 - alphas  # == betas
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
        self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])

    def q_sample(self, y0, t, noise):
        """Forward process: ``sqrt(abar_t) * y0 + sqrt(1 - abar_t) * noise``."""
        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, y0.shape) * y0
                + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, y0.shape) * noise)

    def training_losses(self, x, y, t):
        """MSE between the true and predicted noise for targets ``y`` conditioned on ``x``."""
        noise = torch.randn_like(y)
        y_t = self.q_sample(y, t, noise)
        predict_noise = self.model(torch.cat([x, y_t], dim=1), t)
        return F.mse_loss(predict_noise, noise)

    @torch.no_grad()
    def ddpm(self, x, device):
        """Ancestral DDPM sampling through all ``denoise_steps`` timesteps.

        Args:
            x: batch-first condition tensor; the generated sample has its shape.
            device: device on which the noise is drawn.

        Returns:
            numpy array with the shape of ``x``, clipped to ``[-1, 1]``.
        """
        y = torch.randn(*x.shape, device=device)
        for t in tqdm(reversed(range(0, self.denoise_steps)), total=self.denoise_steps):
            t = torch.tensor([t], device=device).repeat(x.shape[0])
            # No fresh noise is added on the final (t == 0) step.
            t_mask = (t != 0).float().view(-1, *([1] * (len(y.shape) - 1)))
            eps = self.model(torch.cat([x, y], dim=1), t)
            # Posterior mean: (y - beta_t / sqrt(1 - abar_t) * eps) / sqrt(alpha_t)
            y = y - (extract_into_tensor(self.one_minus_alphas, t, y.shape) * eps
                     / extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, y.shape))
            y = y / extract_into_tensor(self.sqrt_alphas, t, y.shape)
            # Posterior std: sqrt(beta_t * (1 - abar_{t-1}) / (1 - abar_t))
            sigma = torch.sqrt(extract_into_tensor(self.one_minus_alphas, t, y.shape)
                               * (1.0 - extract_into_tensor(self.alphas_cumprod_prev, t, y.shape))
                               / (1.0 - extract_into_tensor(self.alphas_cumprod, t, y.shape)))
            y = y + sigma * torch.randn_like(y) * t_mask
        y = y.clip(-1, 1)
        return y.detach().cpu().numpy()

    @torch.no_grad()
    def ddim(self, x, device, eta=0.0, sub_sequence_step=25):
        """DDIM sampling on the strided grid ``T-1, T-1-k, ...`` (``k = sub_sequence_step``).

        On the last visited timestep the model's clean-sample estimate ``y0_pred``
        is returned directly (i.e. the target abar is 1). This matches what
        ``ddpm`` produces at t == 0, where its update reduces to
        ``(y - sqrt(1 - abar_0) * eps) / sqrt(abar_0)``.

        Args:
            x: batch-first condition tensor; the generated sample has its shape.
            device: device on which the noise is drawn.
            eta: 0 for deterministic DDIM; > 0 injects noise on intermediate steps.
            sub_sequence_step: stride between visited timesteps (>= 1).

        Returns:
            numpy array with the shape of ``x``, clipped to ``[-1, 1]``.

        Raises:
            ValueError: if ``sub_sequence_step`` < 1.
        """
        if sub_sequence_step < 1:
            raise ValueError("sub_sequence_step must be a positive integer")
        # Start from Gaussian noise.
        y = torch.randn(*x.shape, device=device)
        # Visited timesteps, from denoise_steps - 1 down towards 0.
        t_seq = list(range(self.denoise_steps - 1, -1, -sub_sequence_step))
        for i in tqdm(range(len(t_seq)), total=len(t_seq)):
            t = t_seq[i]
            t_tensor = torch.tensor([t], device=device).repeat(x.shape[0])
            eps = self.model(torch.cat([x, y], dim=1), t_tensor)
            alpha_bar_t = extract_into_tensor(self.alphas_cumprod, t_tensor, y.shape)
            # Estimate of the clean sample implied by the predicted noise.
            y0_pred = (y - torch.sqrt(1 - alpha_bar_t) * eps) / torch.sqrt(alpha_bar_t)

            if i == len(t_seq) - 1:
                # Final step: jump all the way to the clean estimate.
                y = y0_pred
            else:
                s = t_seq[i + 1]
                s_tensor = torch.tensor([s], device=device).repeat(x.shape[0])
                alpha_bar_s = extract_into_tensor(self.alphas_cumprod, s_tensor, y.shape)
                # Stochasticity control (sigma = 0 gives deterministic DDIM).
                sigma = 0.0
                if eta > 0.0 and s > 0:
                    sigma = eta * torch.sqrt((1 - alpha_bar_s) / (1 - alpha_bar_t)
                                             * (1 - alpha_bar_t / alpha_bar_s))
                # Move to timestep s along the predicted-x0 / predicted-noise directions.
                y = torch.sqrt(alpha_bar_s) * y0_pred + torch.sqrt(1 - alpha_bar_s - sigma ** 2) * eps
                if eta > 0.0 and s > 0:
                    y = y + sigma * torch.randn_like(y)
        y = y.clip(-1, 1)
        return y.detach().cpu().numpy()
"SwinUnet.py"
import math

import numpy as np
import torch as th
from einops import rearrange
from torch import nn, einsum
#############################################
# Sinusoidal 时间步嵌入
#############################################
class SinusoidalTimeEmb(nn.Module):
    """Fixed (non-learned) sinusoidal features for integer diffusion timesteps.

    The output is ``[sin(t * f_k) ..., cos(t * f_k) ...]`` with ``dim // 2``
    frequencies ``f_k`` spaced geometrically from 1 down to 1/10000.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        n_freq = self.dim // 2
        log_step = math.log(10000) / (n_freq - 1)
        freqs = th.exp(th.arange(n_freq, device=t.device) * -log_step)
        angles = t.float()[:, None] * freqs[None, :]
        return th.cat([th.sin(angles), th.cos(angles)], dim=-1)  # [B, 2 * (dim // 2)]
#############################################
# 下采样模块:Patch Merging
#############################################
class PatchMerging(nn.Module):
    """Downsample a channels-last map by folding each f x f patch into channels.

    ``(B, H, W, C_in)`` -> ``(B, H // f, W // f, C_out)``, with
    ``f = downscaling_factor``; the folded ``C_in * f * f`` features are
    projected to ``out_channels`` by a linear layer.
    """

    def __init__(self, in_channels, out_channels, downscaling_factor):
        super().__init__()
        self.downscaling_factor = downscaling_factor
        self.patch_merge = nn.Unfold(kernel_size=downscaling_factor,
                                     stride=downscaling_factor, padding=0)
        self.linear = nn.Linear(in_channels * downscaling_factor ** 2, out_channels)

    def forward(self, x):
        batch, height, width, _ = x.shape
        factor = self.downscaling_factor
        patches = self.patch_merge(x.permute(0, 3, 1, 2))  # (B, C*f*f, num_patches)
        patches = patches.view(batch, -1, height // factor, width // factor)
        return self.linear(patches.permute(0, 2, 3, 1))
#############################################
# 上采样模块:Patch Expanding(线性投影 + 像素重排,并非插值)
#############################################
class PatchExpanding(nn.Module):
    """Upsample a channels-last map by a linear projection plus pixel rearrangement.

    Each input pixel is projected to ``f * f * out_channels`` features, which are
    laid out as an f x f block of output pixels:
    ``(B, H, W, C_in)`` -> ``(B, H * f, W * f, out_channels)``.
    """

    def __init__(self, in_channels, out_channels, upscaling_factor):
        super().__init__()
        self.upscaling_factor = upscaling_factor
        self.out_channels = out_channels
        self.linear = nn.Linear(in_channels, out_channels * self.upscaling_factor ** 2)

    def forward(self, x):
        batch, height, width, _ = x.shape
        factor, channels = self.upscaling_factor, self.out_channels
        blocks = self.linear(x).view(batch, height, width, factor, factor, channels)
        # Interleave so the sub-pixel rows/cols sit next to their source pixel.
        blocks = blocks.permute(0, 1, 3, 2, 4, 5).contiguous()
        return blocks.view(batch, height * factor, width * factor, channels)
#############################################
# 窗口自注意力机制
#############################################
class WindowAttention(nn.Module):
    """(Shifted-)window multi-head self-attention on a channels-last map.

    Input and output are ``(B, H, W, dim)``; ``H`` and ``W`` must be multiples of
    ``window_size`` (required by the window rearrange in ``forward``).

    Args:
        dim: channel width, split evenly across ``nheads`` heads.
        nheads: number of attention heads.
        window_size: side length of each square attention window.
        shifted: if True, roll the map by ``-(window_size // 2)`` before attention,
            mask attention between tokens that were not neighbours before the
            roll, and roll back afterwards.
        relative_pos_embedding: if True use a learned ``(2w-1, 2w-1)`` bias table
            indexed by relative offsets, else a dense ``(w*w, w*w)`` bias.
    """

    def __init__(self, dim, nheads, window_size, shifted, relative_pos_embedding):
        super().__init__()
        head_dim = dim // nheads
        self.nheads = nheads
        self.scale = head_dim ** -0.5
        self.window_size = window_size
        self.relative_pos_embedding = relative_pos_embedding
        self.shifted = shifted

        if self.shifted:
            displacement = window_size // 2
            self.cyclic_shift = CyclicShift(-displacement)
            self.cyclic_back_shift = CyclicShift(displacement)
            # Frozen Parameters (kept as such so they stay in the state_dict and
            # existing checkpoints keep loading).
            self.upper_lower_mask = nn.Parameter(create_mask(window_size=window_size, displacement=displacement,
                                                             upper_lower=True, left_right=False),
                                                 requires_grad=False)
            self.left_right_mask = nn.Parameter(create_mask(window_size=window_size, displacement=displacement,
                                                            upper_lower=False, left_right=True),
                                                requires_grad=False)

        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)

        if self.relative_pos_embedding:
            # Relative offsets lie in [-(w-1), w-1]; shift them to [0, 2w-2] so they
            # index pos_embedding directly. A non-persistent buffer follows the
            # module through .to(device) but is not stored in the state_dict.
            self.register_buffer("relative_indices",
                                 get_relative_distances(window_size) + window_size - 1,
                                 persistent=False)
            self.pos_embedding = nn.Parameter(th.randn(2 * window_size - 1, 2 * window_size - 1))
        else:
            self.pos_embedding = nn.Parameter(th.randn(window_size ** 2, window_size ** 2))

        self.to_out = nn.Linear(dim, dim)

    def forward(self, x):
        if self.shifted:
            x = self.cyclic_shift(x)

        b, n_h, n_w, _, h = *x.shape, self.nheads

        qkv = self.to_qkv(x).chunk(3, dim=-1)
        nw_h = n_h // self.window_size
        nw_w = n_w // self.window_size

        # Split into windows: (b, heads, num_windows, tokens_per_window, head_dim),
        # windows enumerated row-major over (nw_h, nw_w).
        q, k, v = map(
            lambda t: rearrange(t, 'b (nw_h w_h) (nw_w w_w) (h d) -> b h (nw_h nw_w) (w_h w_w) d',
                                h=h, w_h=self.window_size, w_w=self.window_size), qkv)

        dots = einsum('b h w i d, b h w j d -> b h w i j', q, k) * self.scale

        if self.relative_pos_embedding:
            dots += self.pos_embedding[self.relative_indices[:, :, 0], self.relative_indices[:, :, 1]]
        else:
            dots += self.pos_embedding

        if self.shifted:
            # The last nw_w windows form the bottom row; every nw_w-th window
            # starting at nw_w - 1 forms the right column.
            dots[:, :, -nw_w:] += self.upper_lower_mask
            dots[:, :, nw_w - 1::nw_w] += self.left_right_mask

        attn = dots.softmax(dim=-1)

        out = einsum('b h w i j, b h w j d -> b h w i d', attn, v)
        out = rearrange(out, 'b h (nw_h nw_w) (w_h w_w) d -> b (nw_h w_h) (nw_w w_w) (h d)',
                        h=h, w_h=self.window_size, w_w=self.window_size, nw_h=nw_h, nw_w=nw_w)
        out = self.to_out(out)

        if self.shifted:
            out = self.cyclic_back_shift(out)
        return out


class CyclicShift(nn.Module):
    """Roll a channels-last map by ``displacement`` along both spatial axes."""

    def __init__(self, displacement):
        super().__init__()
        self.displacement = displacement

    def forward(self, x):
        return th.roll(x, shifts=(self.displacement, self.displacement), dims=(1, 2))


def create_mask(window_size, displacement, upper_lower, left_right):
    """Build a ``(w*w, w*w)`` additive attention mask (0 or -inf) for shifted windows.

    ``upper_lower`` blocks attention between the top and bottom bands of a
    window; ``left_right`` blocks it between the left and right bands. Tokens
    are flattened row-major within the window.
    """
    mask = th.zeros(window_size ** 2, window_size ** 2)

    if upper_lower:
        mask[-displacement * window_size:, :-displacement * window_size] = float('-inf')
        mask[:-displacement * window_size, -displacement * window_size:] = float('-inf')

    if left_right:
        mask = rearrange(mask, '(h1 w1) (h2 w2) -> h1 w1 h2 w2', h1=window_size, h2=window_size)
        mask[:, -displacement:, :, :-displacement] = float('-inf')
        mask[:, :-displacement, :, -displacement:] = float('-inf')
        mask = rearrange(mask, 'h1 w1 h2 w2 -> (h1 w1) (h2 w2)')

    return mask


def get_relative_distances(window_size):
    """Return ``(w*w, w*w, 2)`` pairwise (row, col) offsets between window tokens.

    ``distances[i, j] = coord[j] - coord[i]``; the dtype is forced to int64 so it
    can be used for indexing regardless of numpy's platform default int.
    """
    indices = th.tensor(np.array([[x, y] for x in range(window_size) for y in range(window_size)]),
                        dtype=th.long)
    distances = indices[None, :, :] - indices[:, None, :]
    return distances
#############################################
# SwinTransformerBlock: 采用和DiT相同的Adaptive Layer Normalization
#############################################
def modulate(x, shift, scale):
    """Per-sample affine modulation of a channels-last map: ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` are ``(B, C)`` and are broadcast over both spatial axes.
    """
    scale_b = scale[:, None, None, :]
    shift_b = shift[:, None, None, :]
    return x * (1 + scale_b) + shift_b


class SwinTransformerAdaLnBlock(nn.Module):
    """Swin block whose LayerNorms are modulated by a conditioning vector (adaLN).

    ``SiLU -> Linear(dim, 6 * dim)`` maps the ``(B, dim)`` conditioning ``t`` to a
    (shift, scale, gate) triple for the attention branch and another for the MLP
    branch; each gated branch output is added back residually. The modulation
    Linear keeps PyTorch's default initialisation (no zero-init is applied here).
    """

    def __init__(self, dim, mlp_ratio, nheads, window_size, shifted, relative_pos_embedding):
        super().__init__()
        self.dim = dim
        hidden = mlp_ratio * dim
        self.attn = WindowAttention(dim, nheads, window_size, shifted, relative_pos_embedding)
        # Affine-free norms: scale and shift come from adaLN_modulation instead.
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.mlp = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim))

    def forward(self, x, t):
        (shift_msa, scale_msa, gate_msa,
         shift_mlp, scale_mlp, gate_mlp) = self.adaLN_modulation(t).chunk(6, dim=1)
        attn_out = self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
        x = x + gate_msa[:, None, None, :] * attn_out
        mlp_out = self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x + gate_mlp[:, None, None, :] * mlp_out
#############################################
# SwinUnet blocks 各组件
#############################################
def block_forward(block, x, t):for b in block:x = b(x, t[:, :b.dim])return xclass SwinUnetEncoder(nn.Module):def __init__(self, channels, dim, patch_size, depth, mlp_ratio, nheads, window_size, relative_pos_embedding):super().__init__()self.patch_embed = PatchMerging(channels, dim, patch_size)self.block0 = nn.ModuleList([SwinTransformerAdaLnBlock(dim=dim * 1,mlp_ratio=mlp_ratio,nheads=nheads[0],window_size=window_size,shifted=True if i // 2 == 0 else False,relative_pos_embedding=relative_pos_embedding) for i in range(1, depth[0] + 1)])self.patch_merge0 = PatchMerging(dim * 1, dim * 2, downscaling_factor=2)self.block1 = nn.ModuleList([SwinTransformerAdaLnBlock(dim=dim * 2,mlp_ratio=mlp_ratio,nheads=nheads[1],window_size=window_size,shifted=True if i // 2 == 0 else False,relative_pos_embedding=relative_pos_embedding) for i in range(1, depth[1] + 1)])self.patch_merge1 = PatchMerging(dim * 2, dim * 4, downscaling_factor=2)self.block2 = nn.ModuleList([SwinTransformerAdaLnBlock(dim=dim * 4,mlp_ratio=mlp_ratio,nheads=nheads[2],window_size=window_size,shifted=True if i // 2 == 0 else False,relative_pos_embedding=relative_pos_embedding) for i in range(1, depth[2] + 1)])self.patch_merge2 = PatchMerging(dim * 4, dim * 8, downscaling_factor=2)def forward(self, x, t):x = x.permute(0, 2, 3, 1)skip_connections = []x = self.patch_embed(x)x = block_forward(self.block0, x, t)skip_connections.append(x)x = self.patch_merge0(x)x = block_forward(self.block1, x, t)skip_connections.append(x)x = self.patch_merge1(x)x = block_forward(self.block2, x, t)skip_connections.append(x)x = self.patch_merge2(x)return x, skip_connectionsclass SwinUnetDecoder(nn.Module):def __init__(self, channels, dim, patch_size, depth, mlp_ratio, nheads, window_size, relative_pos_embedding):super().__init__()self.patch_expand0 = PatchExpanding(dim * 8, dim * 4, upscaling_factor=2)self.block0 = nn.ModuleList([SwinTransformerAdaLnBlock(dim=dim * 
4,mlp_ratio=mlp_ratio,nheads=nheads[2],window_size=window_size,shifted=True if i // 2 == 0 else False,relative_pos_embedding=relative_pos_embedding) for i in range(1, depth[2] + 1)])self.skip0 = nn.Linear(dim * 4 * 2, dim * 4, bias=False)self.patch_expand1 = PatchExpanding(dim * 4, dim * 2, upscaling_factor=2)self.block1 = nn.ModuleList([SwinTransformerAdaLnBlock(dim=dim * 2,mlp_ratio=mlp_ratio,nheads=nheads[1],window_size=window_size,shifted=True if i // 2 == 0 else False,relative_pos_embedding=relative_pos_embedding) for i in range(1, depth[1] + 1)])self.skip1 = nn.Linear(dim * 2 * 2, dim * 2, bias=False)self.patch_expand2 = PatchExpanding(dim * 2, dim * 1, upscaling_factor=2)self.block2 = nn.ModuleList([SwinTransformerAdaLnBlock(dim=dim * 1,mlp_ratio=mlp_ratio,nheads=nheads[0],window_size=window_size,shifted=True if i // 2 == 0 else False,relative_pos_embedding=relative_pos_embedding) for i in range(1, depth[0] + 1)])self.skip2 = nn.Linear(dim * 1 * 2, dim * 1, bias=False)self.patch_to_image = PatchExpanding(dim, channels, patch_size)def forward(self, x, skip_connect, t):x = self.patch_expand0(x)x = th.cat((x, skip_connect[2]), dim=-1)x = self.skip0(x)x = block_forward(self.block0, x, t)x = self.patch_expand1(x)x = th.cat((x, skip_connect[1]), dim=-1)x = self.skip1(x)x = block_forward(self.block1, x, t)x = self.patch_expand2(x)x = th.cat((x, skip_connect[0]), dim=-1)x = self.skip2(x)x = block_forward(self.block2, x, t)x = self.patch_to_image(x)return x.permute(0, 3, 1, 2)#############################################
# SwinUnet: 条件的处理采用直接拼接
#############################################
class SwinUnet(nn.Module):
    """Conditional Swin-UNet noise predictor with adaLN timestep conditioning.

    With ``use_condition`` the encoder expects ``2 * channels`` input channels:
    the condition image and the noisy sample concatenated on dim 1 (this is how
    the Scheduler calls the model). The output has ``channels`` channels. The
    timestep embedding has ``8 * dim`` features; each block uses only its first
    ``block.dim`` of them.
    """

    def __init__(self, channels, dim, mlp_ratio, patch_size, window_size, depth, nheads,
                 relative_pos_embedding=True, use_condition=True):
        super().__init__()
        in_channels = 2 * channels if use_condition else channels
        self.time_embed = SinusoidalTimeEmb(8 * dim)
        self.encoder = SwinUnetEncoder(channels=in_channels, dim=dim, patch_size=patch_size,
                                       depth=depth[:3], mlp_ratio=mlp_ratio, nheads=nheads[:3],
                                       window_size=window_size,
                                       relative_pos_embedding=relative_pos_embedding)
        bottleneck_blocks = []
        for i in range(1, depth[-1] + 1):
            # For i >= 1, `i // 2 == 0` holds only for i == 1: just the first block is shifted.
            bottleneck_blocks.append(SwinTransformerAdaLnBlock(dim=8 * dim,
                                                               mlp_ratio=mlp_ratio,
                                                               nheads=nheads[-1],
                                                               window_size=window_size,
                                                               shifted=(i == 1),
                                                               relative_pos_embedding=relative_pos_embedding))
        self.bottleneck = nn.ModuleList(bottleneck_blocks)
        self.decoder = SwinUnetDecoder(channels=channels, dim=dim, patch_size=patch_size,
                                       depth=depth[:3], mlp_ratio=mlp_ratio, nheads=nheads[:3],
                                       window_size=window_size,
                                       relative_pos_embedding=relative_pos_embedding)

    def forward(self, x, t):
        t_emb = self.time_embed(t)
        latent, skips = self.encoder(x, t_emb)
        latent = block_forward(self.bottleneck, latent, t_emb)
        return self.decoder(latent, skips, t_emb)
"load_data.py"
import os
import numpy as npfrom PIL import Image
from torch.utils.data import Datasetdef is_image_file(file_path):# 定义常见的图片文件扩展名image_extensions = {'.jpg', '.jpeg', '.png'}# 获取文件的扩展名并判断是否在图片扩展名集合中file_extension = os.path.splitext(file_path)[1].lower()return file_extension in image_extensionsclass CustomDataset(Dataset):def __init__(self, path, img_size=None, sr_ratio=8):super().__init__()files = os.listdir(path)self.img_size = img_sizeself.files = []for file in files:self.files.append(os.path.join(path, file))self.ratio = sr_ratiodef __len__(self):return len(self.files)def __getitem__(self, idx):hr_img = Image.open(self.files[idx])if self.img_size is not None:hr_img = hr_img.resize(self.img_size)hr_size = hr_img.sizeelse:hr_size = hr_img.sizehr_size = ((hr_size[0] // 256 + 1) * 256, (hr_size[1] // 256 + 1) * 256)hr_img = hr_img.resize(hr_size)hr_arr = np.array(hr_img).transpose(2, 0, 1) / 255.lr_img = hr_img.resize((hr_size[0] // self.ratio, hr_size[1] // self.ratio))lr_img = lr_img.resize(hr_size)lr_arr = np.array(lr_img).transpose(2, 0, 1) / 255.lr_arr = 2 * (lr_arr - 0.5)hr_arr = 2 * (hr_arr - 0.5)return lr_arr, hr_arrclass ImageNetDataset(Dataset):def __init__(self, path, img_size=(256, 256), sr_ratio=8):super().__init__()self.img_size = img_sizeclass_dirs = os.listdir(path)self.files = []for class_dir in class_dirs:files = os.listdir(os.path.join(path, class_dir))for file in files:if is_image_file(os.path.join(path, class_dir, file)):self.files.append(os.path.join(path, class_dir, file))self.ratio = sr_ratiodef __len__(self):return len(self.files)def __getitem__(self, idx):hr_img = Image.open(self.files[idx])hr_img = hr_img.convert("RGB")if self.img_size is not None:hr_img = hr_img.resize(self.img_size)hr_size = hr_img.sizehr_arr = np.array(hr_img).transpose(2, 0, 1) / 255.lr_img = hr_img.resize((hr_size[0] // self.ratio, hr_size[1] // self.ratio))lr_img = lr_img.resize(hr_size)lr_arr = np.array(lr_img).transpose(2, 0, 1) / 255.lr_arr = 2 * (lr_arr - 0.5)hr_arr = 2 * (hr_arr - 0.5)return lr_arr, hr_arr
"training.py"
import os

import numpy as np
import torch
from torch import optim
from torch.autograd import Variable  # NOTE: unused; Variable has been a no-op wrapper since PyTorch 0.4
from torch.utils.data import DataLoader
from tqdm import tqdm

from load_data import CustomDataset
from scheduler import Scheduler
from SwinUnet import SwinUnet


def _pick_device():
    """Prefer CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


if __name__ == '__main__':
    device = _pick_device()
    batch_size = 16
    lr = 1e-4
    epochs = 200
    denoise_steps = 1000
    sr_ratio = 8
    resume_path = "SwinUNet-SR8.pth"
    save_path = f"unet-sr{sr_ratio}.pth"

    train_dataset = CustomDataset("./DIV2K_train_HR", img_size=(512, 512), sr_ratio=sr_ratio)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    model = SwinUnet(channels=3, dim=96, mlp_ratio=4, patch_size=4, window_size=8,
                     depth=[2, 2, 6, 2], nheads=[3, 6, 12, 24]).to(device)
    # Resume only when the checkpoint exists, so training from scratch also works.
    # NOTE(review): torch.load unpickles -- only resume from trusted files.
    if os.path.exists(resume_path):
        model.load_state_dict(torch.load(resume_path, map_location=device))
    else:
        print(f"No checkpoint at {resume_path}; training from scratch")
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    scheduler = Scheduler(model, denoise_steps)

    model.train()
    for epoch in range(epochs):
        print('*' * 40)
        train_loss = []
        for x, y in tqdm(train_loader, total=len(train_loader)):
            # The dataset yields float64 arrays; cast before moving (MPS lacks float64).
            x = x.to(torch.float32).to(device)
            y = y.to(torch.float32).to(device)
            t = torch.randint(low=0, high=denoise_steps, size=(x.shape[0],)).to(device)
            training_loss = scheduler.training_losses(x, y, t)
            optimizer.zero_grad()
            training_loss.backward()
            optimizer.step()
            train_loss.append(training_loss.item())
        torch.save(model.state_dict(), save_path)
        print('Finish  {}  Loss: {:.6f}'.format(epoch + 1, np.mean(train_loss)))
关键字:微信建站网站_北京校园文化设计公司_引擎优化seo怎么做_网站交易网

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

责任编辑: