愿我们终有重逢之时,而你还记得我们曾经讨论的话题。
group 868373192
second group 277356808
1. 导入必要的库和模块
import os os.system('pip install gradio_imageslider') import gradio as gr from gradio_imageslider import ImageSlider from loadimg import load_img from transformers import AutoModelForImageSegmentation import torch from torchvision import transforms from modelscope import snapshot_download
解读:
-
os.system('pip install gradio_imageslider')
: 使用os.system
命令安装gradio_imageslider
库。这个库可能是用于图像滑块的 Gradio 组件。 -
import gradio as gr
: 导入 Gradio 库,用于创建交互式 Web 界面。 -
from gradio_imageslider import ImageSlider
: 从gradio_imageslider
库中导入ImageSlider
组件。 -
from loadimg import load_img
: 导入自定义的load_img
函数,用于加载图像。 -
from transformers import AutoModelForImageSegmentation
: 导入 Hugging Face 的transformers
库中的AutoModelForImageSegmentation
模型,用于图像分割。 -
import torch
: 导入 PyTorch 库,用于深度学习模型的训练和推理。 -
from torchvision import transforms
: 导入torchvision
库中的transforms
模块,用于图像预处理。 -
from modelscope import snapshot_download
: 导入modelscope
库中的snapshot_download
函数,用于下载预训练模型。
2. 设置 PyTorch 的浮点运算精度
torch.set_float32_matmul_precision(["high", "highest"][0])
解读:
-
torch.set_float32_matmul_precision
: 设置 PyTorch 中浮点矩阵乘法的精度。这里选择的是"high"
精度。
3. 下载并加载预训练模型
model_dir = snapshot_download("modelscope/BiRefNet") birefnet = AutoModelForImageSegmentation.from_pretrained(model_dir, trust_remote_code=True ) birefnet.to("cuda")
解读:
-
snapshot_download("modelscope/BiRefNet")
: 从modelscope
下载BiRefNet
模型的预训练权重。 -
AutoModelForImageSegmentation.from_pretrained(model_dir, trust_remote_code=True)
: 使用 Hugging Face 的AutoModelForImageSegmentation
类加载预训练模型。trust_remote_code=True
表示信任远程代码。 -
birefnet.to("cuda")
: 将模型移动到 GPU 上以加速推理。
4. 定义图像预处理管道
transform_image = transforms.Compose([transforms.Resize((1024, 1024)),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),] )
解读:
-
transforms.Compose
: 将多个图像变换操作组合在一起。 -
transforms.Resize((1024, 1024))
: 将图像调整为 1024x1024 像素。 -
transforms.ToTensor()
: 将图像转换为 PyTorch 张量。 -
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
: 对图像进行归一化处理,使用 ImageNet 的均值和标准差。
5. 定义图像分割函数
def fn(image):im = load_img(image, output_type="pil")im = im.convert("RGB")image_size = im.sizeorigin = im.copy()image = load_img(im)input_images = transform_image(image).unsqueeze(0).to("cuda")# Predictionwith torch.no_grad():preds = birefnet(input_images)[-1].sigmoid().cpu()pred = preds[0].squeeze()pred_pil = transforms.ToPILImage()(pred)mask = pred_pil.resize(image_size)image.putalpha(mask)return (image, origin)
解读:
-
fn(image)
: 这是一个用于图像分割的函数。 -
im = load_img(image, output_type="pil")
: 加载图像并将其转换为 PIL 图像对象。 -
im = im.convert("RGB")
: 将图像转换为 RGB 格式。 -
image_size = im.size
: 获取图像的尺寸。 -
origin = im.copy()
: 保存原始图像的副本。 -
input_images = transform_image(image).unsqueeze(0).to("cuda")
: 对图像进行预处理,并将其转换为模型输入格式。 -
with torch.no_grad()
: 禁用梯度计算,以节省内存并加速推理。 -
preds = birefnet(input_images)[-1].sigmoid().cpu()
: 使用模型进行预测,并将结果转换为 CPU 上的张量。 -
pred = preds[0].squeeze()
: 提取预测结果并去除多余的维度。 -
pred_pil = transforms.ToPILImage()(pred)
: 将预测结果转换为 PIL 图像。 -
mask = pred_pil.resize(image_size)
: 将预测结果调整为原始图像的尺寸。 -
image.putalpha(mask)
: 将预测的掩码应用到原始图像上。 -
return (image, origin)
: 返回处理后的图像和原始图像。
6. 创建 Gradio 界面组件
slider1 = ImageSlider(label="birefnet", type="pil") slider2 = ImageSlider(label="birefnet", type="pil") image = gr.Image(label="上传图片") text = gr.Textbox(label="粘贴图片URL")
解读:
-
slider1
和slider2
: 创建两个ImageSlider
组件,用于显示图像分割结果。 -
image
: 创建一个gr.Image
组件,用于上传图片。 -
text
: 创建一个gr.Textbox
组件,用于粘贴图片的 URL。
7. 加载示例图像
chameleon = load_img("butterfly.jpg", output_type="pil") url = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
解读:
-
chameleon
: 加载一个名为butterfly.jpg
的示例图像。 -
url
: 定义一个示例图像的 URL。
8. 创建 Gradio 界面
tab1 = gr.Interface(fn, inputs=image, outputs=slider1, examples=[chameleon], api_name="image" )tab2 = gr.Interface(fn, inputs=text, outputs=slider2, examples=[url], api_name="text")demo = gr.TabbedInterface([tab1, tab2], ["图片", "链接"], title="birefnet通用抠图" )
解读:
-
tab1
: 创建一个 Gradio 界面,使用image
组件作为输入,slider1
作为输出,并提供一个示例图像。 -
tab2
: 创建另一个 Gradio 界面,使用text
组件作为输入,slider2
作为输出,并提供一个示例 URL。 -
demo
: 创建一个带有标签的 Tabbed 界面,包含两个选项卡:一个用于上传图片,另一个用于粘贴图片 URL。
9. 启动 Gradio 界面
python
复制
if __name__ == "__main__":demo.launch(max_threads=150, share=False, inbrowser=True, server_name="0.0.0.0", server_port=8001)
解读:
-
if __name__ == "__main__":
: 只有在直接运行脚本时才会执行以下代码。 -
demo.launch(...)
: 启动 Gradio 界面,设置最大线程数为 150,不共享链接,自动打开浏览器,服务器地址为0.0.0.0
,端口为 8001。
功能介绍
这个代码实现了一个基于 BiRefNet
模型的图像分割应用。用户可以通过上传图片或粘贴图片 URL 的方式,使用预训练的 BiRefNet
模型对图像进行分割,并生成一个带有分割掩码的图像。最终结果通过 Gradio 界面展示,用户可以在界面上直观地看到分割效果。