Python案例如何实现图像超分?

wen python案例 3

本文目录导读:

Python案例如何实现图像超分?

  1. 环境准备
  2. 传统插值方法(基础版)
  3. 基于SRCNN的深度学习方法
  4. 基于ESPCN的实时超分方法
  5. 使用预训练模型(推荐方法)
  6. 完整的实用示例
  7. 评估超分效果

本文详细讲解用Python实现图像超分辨率(简称"超分")的几种方法,从传统插值方法到深度学习方法。

环境准备

首先安装必要的库:

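pip install opencv-python matplotlib numpy scikit-image
pip install torch torchvision  # 深度学习框架(本文的SRCNN/ESPCN示例基于PyTorch)
pip install tensorflow         # 可选:本文示例未使用TensorFlow
pip install transformers pillow  # 第5节的Hugging Face预训练模型示例需要
# 注意:第5节用到的 cv2.dnn_superres 模块属于 opencv-contrib-python,需用它替换 opencv-python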

传统插值方法(基础版)

import cv2
import numpy as np
import matplotlib.pyplot as plt
def traditional_interpolation(image_path, scale_factor=2):
    """Upscale an image with several classic OpenCV interpolation kernels.

    Args:
        image_path: Path of the image file, read with ``cv2.imread``.
        scale_factor: Multiplier applied to both width and height. Fractional
            values are accepted; the target size is rounded to whole pixels.

    Returns:
        A dict mapping the kernel name (``'Nearest'``, ``'Linear'``,
        ``'Cubic'``, ``'Lanczos'``) to the upscaled RGB image array.

    Raises:
        ValueError: If ``scale_factor`` is not positive.
        FileNotFoundError: If no image could be read from ``image_path``.
    """
    if scale_factor <= 0:
        raise ValueError(f"scale_factor must be positive, got {scale_factor}")

    # cv2.imread signals failure by returning None instead of raising.
    img = cv2.imread(image_path)
    if img is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # cv2.resize needs integer pixel counts, so round fractional targets.
    h, w = img.shape[:2]
    new_h = max(1, int(round(h * scale_factor)))
    new_w = max(1, int(round(w * scale_factor)))

    methods = {
        'Nearest': cv2.INTER_NEAREST,
        'Linear': cv2.INTER_LINEAR,
        'Cubic': cv2.INTER_CUBIC,
        'Lanczos': cv2.INTER_LANCZOS4,
    }
    # cv2.resize takes the target size as (width, height).
    return {
        name: cv2.resize(img_rgb, (new_w, new_h), interpolation=flag)
        for name, flag in methods.items()
    }
# Usage example: upscale one image with every kernel.
results = traditional_interpolation('input_image.jpg', scale_factor=2)
# Show the four results on a 2x2 grid, filled row by row.
fig, axes = plt.subplots(2, 2, figsize=(12, 12))
for idx, (name, img) in enumerate(results.items()):
    row, col = divmod(idx, 2)
    ax = axes[row, col]
    ax.imshow(img)
    ax.set_title(f'{name} Interpolation')
    ax.axis('off')
plt.show()

基于SRCNN的深度学习方法

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from skimage import io, transform
class SRCNN(nn.Module):
    """Three-layer convolutional network for single-channel super-resolution.

    Follows "Image Super-Resolution Using Deep Convolutional Networks".
    Each convolution is padded by half its kernel size, so the output has the
    same height and width as the input.
    """

    def __init__(self):
        super().__init__()
        # 9x9, 1x1 and 5x5 kernels in that order.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=9, padding=4)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=1, padding=0)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=1, kernel_size=5, padding=2)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Run the three convolutions; the last one has no activation."""
        features = self.relu(self.conv1(x))
        mapped = self.relu(self.conv2(features))
        return self.conv3(mapped)
class SRDataset(Dataset):
    """Dataset of random low-/high-resolution grayscale patch pairs.

    Each item is cut from one image: an HR crop of side
    ``patch_size * scale_factor`` and its LR version resized down to
    ``patch_size``. The two tensors therefore differ in spatial size, so the
    pair only fits models that upsample internally; a size-preserving network
    needs the LR patch interpolated back to the HR size first.

    Args:
        image_paths: Sequence of image file paths.
        scale_factor: Integer ratio between HR and LR patch sides.
        patch_size: Side length of the LR patch in pixels.
    """

    def __init__(self, image_paths, scale_factor=2, patch_size=33):
        self.image_paths = image_paths
        self.scale_factor = scale_factor
        self.patch_size = patch_size

    def __len__(self):
        return len(self.image_paths)

    def _random_hr_crop(self, img):
        """Return a random square HR crop from a 2-D image array.

        Raises:
            ValueError: If the image is smaller than the crop in either axis.
        """
        crop = self.patch_size * self.scale_factor
        h, w = img.shape
        if h < crop or w < crop:
            raise ValueError(
                f"Image of size {h}x{w} is smaller than the {crop}x{crop} HR crop"
            )
        # randint's upper bound is exclusive: +1 keeps the last valid offset
        # reachable and lets an image of exactly the crop size work.
        top = np.random.randint(0, h - crop + 1)
        left = np.random.randint(0, w - crop + 1)
        return img[top:top + crop, left:left + crop]

    def __getitem__(self, idx):
        """Return ``(lr_tensor, hr_tensor)``, each with a leading channel axis."""
        img = io.imread(self.image_paths[idx], as_gray=True)
        # Integer-typed images are scaled to [0, 1] so the HR patch uses the
        # same range as the LR patch produced by transform.resize.
        # NOTE(review): assumes resize rescales integer input to [0, 1] - confirm.
        if np.issubdtype(img.dtype, np.integer):
            img = img.astype(np.float64) / np.iinfo(img.dtype).max

        hr_patch = self._random_hr_crop(img)
        # Downsample the HR crop to get the matching LR patch.
        lr_patch = transform.resize(hr_patch, (self.patch_size, self.patch_size))

        lr_tensor = torch.FloatTensor(lr_patch).unsqueeze(0)
        hr_tensor = torch.FloatTensor(hr_patch).unsqueeze(0)
        return lr_tensor, hr_tensor
def train_srcnn(model, train_loader, num_epochs=100):
    """Train a super-resolution model with MSE loss and Adam (lr=0.001).

    Args:
        model: Network to optimise in place; it is switched to train mode.
        train_loader: Iterable yielding ``(lr, hr)`` batches. Its length is
            never queried, so iterables without ``len()`` also work.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        A list with the mean batch loss of each epoch, ``nan`` for an epoch
        that yielded no batches.
    """
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    history = []

    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        num_batches = 0
        for lr_batch, hr_batch in train_loader:
            optimizer.zero_grad()
            # Forward pass.
            output = model(lr_batch)
            loss = criterion(output, hr_batch)
            # Backward pass and parameter update.
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1

        # Count batches while iterating so an empty loader cannot cause a
        # division by zero.
        epoch_loss = running_loss / num_batches if num_batches else float("nan")
        history.append(epoch_loss)
        if (epoch + 1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')
    return history
# 使用示例
def srcnn_super_resolution(model, image_path, scale_factor=2):
    """Degrade an image by ``scale_factor`` and restore it with ``model``.

    The image is read as grayscale, shrunk by ``scale_factor``, enlarged back
    to its original size with bicubic interpolation, and the enlarged image is
    passed through the network.

    Args:
        model: Trained network taking a ``(1, 1, H, W)`` float tensor.
        image_path: Path of the image to read.
        scale_factor: Integer downsampling ratio (>= 1).

    Returns:
        ``(sr_img, lr_upscaled)``: the network output as a NumPy array and the
        bicubic-upscaled input it was computed from.

    Raises:
        ValueError: If ``scale_factor`` is below 1 or the image is too small
            to be downsampled by it.
    """
    if scale_factor < 1:
        raise ValueError(f"scale_factor must be >= 1, got {scale_factor}")

    img = io.imread(image_path, as_gray=True)
    h, w = img.shape
    low_h, low_w = h // scale_factor, w // scale_factor
    if low_h < 1 or low_w < 1:
        raise ValueError(
            f"Image of size {h}x{w} is too small for scale_factor={scale_factor}"
        )

    # order=3 selects bicubic interpolation; resize defaults to order=1.
    lr_img = transform.resize(img, (low_h, low_w), order=3)
    lr_upscaled = transform.resize(lr_img, (h, w), order=3)

    # Add batch and channel axes.
    input_tensor = torch.FloatTensor(lr_upscaled).unsqueeze(0).unsqueeze(0)
    # Run on whichever device the model's weights are on.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        input_tensor = input_tensor.to(first_param.device)

    model.eval()
    with torch.no_grad():
        output = model(input_tensor)

    # .cpu() is required before .numpy() when the model runs on a GPU.
    sr_img = output.squeeze().cpu().numpy()
    return sr_img, lr_upscaled

基于ESPCN的实时超分方法

class ESPCN(nn.Module):
    """Sub-pixel convolutional network for single-channel super-resolution.

    Follows "Real-Time Single Image and Video Super-Resolution". The
    convolutions keep the input's spatial size; the last one emits
    ``scale_factor ** 2`` channels, which ``PixelShuffle`` rearranges into a
    single channel that is ``scale_factor`` times larger in each dimension.
    """

    def __init__(self, scale_factor=2):
        super().__init__()
        self.scale_factor = scale_factor
        sub_pixel_channels = scale_factor ** 2
        self.conv1 = nn.Conv2d(1, 64, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(32, sub_pixel_channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.pixel_shuffle = nn.PixelShuffle(scale_factor)

    def forward(self, x):
        """Apply two ReLU convolutions, a linear convolution, then the shuffle."""
        for conv in (self.conv1, self.conv2):
            x = self.relu(conv(x))
        return self.pixel_shuffle(self.conv3(x))

使用预训练模型(推荐方法)

# 方法1:使用OpenCV DNN模块
def opencv_dnn_super_resolution(image_path, model_path, model_name="edsr", scale=2):
    """Upscale an image with OpenCV's ``dnn_superres`` module.

    Pretrained model files are linked from
    https://github.com/opencv/opencv_contrib/tree/master/modules/dnn_superres
    NOTE(review): ``cv2.dnn_superres`` is part of the opencv_contrib modules,
    so the ``opencv-contrib-python`` wheel is presumably required - confirm.

    Args:
        image_path: Path of the image to upscale.
        model_path: Path of the pretrained ``.pb`` model file.
        model_name: Algorithm the model file implements: ``"edsr"``,
            ``"espcn"``, ``"fsrcnn"`` or ``"lapsrn"`` (case-insensitive).
        scale: Upscaling factor the model file was trained for.

    Returns:
        The upscaled image as returned by ``upsample`` (BGR, like the input).

    Raises:
        ValueError: If ``model_name`` is not a supported algorithm.
        FileNotFoundError: If no image could be read from ``image_path``.
    """
    supported = ("edsr", "espcn", "fsrcnn", "lapsrn")
    algo = str(model_name).lower()
    if algo not in supported:
        raise ValueError(
            f"Unsupported model_name {model_name!r}; expected one of {supported}"
        )

    import cv2.dnn_superres

    # cv2.imread signals failure by returning None instead of raising.
    img = cv2.imread(image_path)
    if img is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")

    sr = cv2.dnn_superres.DnnSuperResImpl_create()
    sr.readModel(model_path)
    # Algorithm name and scale must match the loaded model file.
    sr.setModel(algo, scale)
    return sr.upsample(img)
# 方法2:使用Hugging Face的预训练模型
def huggingface_super_resolution(image):
    """Upscale a PIL image with the pretrained Swin2SR x2 checkpoint.

    The processor and model are loaded from the Hugging Face hub on every
    call, so callers processing many images may want to cache them.

    Args:
        image: A ``PIL.Image.Image``; it is converted to RGB before inference.

    Returns:
        A ``PIL.Image.Image`` whose width and height are the input's
        multiplied by the model's upscale factor.
    """
    from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution
    from PIL import Image
    import torch

    checkpoint = "caidas/swin2SR-classical-sr-x2-v2"
    processor = AutoImageProcessor.from_pretrained(checkpoint)
    model = Swin2SRForImageSuperResolution.from_pretrained(checkpoint)
    model.eval()

    image = image.convert("RGB")
    inputs = processor(image, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)

    # The upscaled image is in `outputs.reconstruction` as a (1, C, H, W)
    # float tensor; clamp to [0, 1] and convert to HWC uint8.
    recon = outputs.reconstruction.squeeze(0).float().cpu().clamp_(0, 1)
    pixels = (recon.permute(1, 2, 0).contiguous() * 255.0).round().to(torch.uint8).numpy()

    # NOTE(review): the processor presumably pads the bottom/right edges to a
    # multiple of the window size, so crop back to the exact size - confirm.
    upscale = model.config.upscale
    width, height = image.size
    pixels = pixels[: height * upscale, : width * upscale]
    return Image.fromarray(pixels)

完整的实用示例

import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
class ImageSuperResolver:
    """Facade that upscales images with a selectable method.

    ``self.methods`` maps a method name to a callable taking
    ``(rgb_image, scale_factor)``. Only ``'bicubic'`` is implemented for real;
    ``'srcnn'`` and ``'espcn'`` are placeholders that fall back to bicubic
    interpolation until trained weights are wired in.
    """

    def __init__(self):
        self.methods = {
            'bicubic': self._bicubic_upscale,
            'srcnn': self._srcnn_upscale,
            'espcn': self._espcn_upscale,
        }

    @staticmethod
    def _load_rgb(image_path):
        """Read an image file and return it as an RGB array.

        Raises:
            FileNotFoundError: If no image could be read from ``image_path``.
        """
        # cv2.imread signals failure by returning None instead of raising.
        img = cv2.imread(image_path)
        if img is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def _bicubic_upscale(self, img, scale_factor):
        """Bicubic interpolation; fractional targets are rounded to pixels."""
        h, w = img.shape[:2]
        new_size = (int(round(w * scale_factor)), int(round(h * scale_factor)))
        return cv2.resize(img, new_size, interpolation=cv2.INTER_CUBIC)

    def _srcnn_upscale(self, img, scale_factor):
        """SRCNN placeholder: currently returns the bicubic result."""
        # TODO: load a trained SRCNN and run it here. No untrained network is
        # built in the meantime because its output would be unused.
        return self._bicubic_upscale(img, scale_factor)

    def _espcn_upscale(self, img, scale_factor):
        """ESPCN placeholder: currently returns the bicubic result."""
        # TODO: load a trained ESPCN for this scale factor and run it here.
        return self._bicubic_upscale(img, scale_factor)

    def upscale(self, image_path, scale_factor=2, method='bicubic'):
        """Upscale one image file and return the RGB result.

        Raises:
            ValueError: If ``method`` is not a key of ``self.methods``.
            FileNotFoundError: If the image cannot be read.
        """
        # Check the method first so a typo fails before any file I/O.
        if method not in self.methods:
            raise ValueError(f"Unknown method: {method}")
        img_rgb = self._load_rgb(image_path)
        return self.methods[method](img_rgb, scale_factor)

    def compare_methods(self, image_path, scale_factor=2):
        """Run every registered method on one image and plot the results.

        Returns:
            A dict mapping each method name to its upscaled RGB image.

        Raises:
            FileNotFoundError: If the image cannot be read.
        """
        img_rgb = self._load_rgb(image_path)
        results = {
            name: func(img_rgb, scale_factor)
            for name, func in self.methods.items()
        }

        # The original comes first, then one panel per method.
        panels = [(f'Original ({img_rgb.shape[1]}x{img_rgb.shape[0]})', img_rgb)]
        panels.extend(
            (f'{name.upper()} ({result.shape[1]}x{result.shape[0]})', result)
            for name, result in results.items()
        )

        # Two columns and as many rows as needed, so adding a method to
        # self.methods cannot overflow the grid.
        n_rows = (len(panels) + 1) // 2
        fig, axes = plt.subplots(n_rows, 2, figsize=(15, 7.5 * n_rows), squeeze=False)
        for ax in axes.flat:
            ax.axis('off')
        for ax, (title, picture) in zip(axes.flat, panels):
            ax.imshow(picture)
            ax.set_title(title)
        plt.tight_layout()
        plt.show()
        return results
# 使用示例
def main():
    """Demo: upscale one image, compare all methods, then batch-process a folder.

    Reads ``input.jpg`` and every ``.jpg`` / ``.png`` / ``.jpeg`` file in
    ``input_images/`` (extension matched case-insensitively) and writes the
    results to ``output.jpg`` and ``output_images/sr_<name>``.
    """
    import os

    resolver = ImageSuperResolver()

    # Single image, single method. The resolver returns RGB while cv2.imwrite
    # expects BGR.
    result = resolver.upscale('input.jpg', scale_factor=2, method='bicubic')
    cv2.imwrite('output.jpg', cv2.cvtColor(result, cv2.COLOR_RGB2BGR))

    # Side-by-side comparison of every method.
    resolver.compare_methods('input.jpg', scale_factor=2)

    # Batch processing.
    input_dir = 'input_images/'
    output_dir = 'output_images/'
    os.makedirs(output_dir, exist_ok=True)
    # Sorted so the processing order does not depend on the file system.
    for img_name in sorted(os.listdir(input_dir)):
        # Lower-case before matching so files such as "photo.JPG" are kept.
        if not img_name.lower().endswith(('.jpg', '.png', '.jpeg')):
            continue
        input_path = os.path.join(input_dir, img_name)
        result = resolver.upscale(input_path, scale_factor=3)
        output_path = os.path.join(output_dir, f'sr_{img_name}')
        cv2.imwrite(output_path, cv2.cvtColor(result, cv2.COLOR_RGB2BGR))
        print(f'Processed: {img_name}')


if __name__ == '__main__':
    main()

评估超分效果

from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
def evaluate_super_resolution(original, super_resolved):
    """Compute PSNR and SSIM between a reference and a super-resolved image.

    Both metrics use ``data_range=255``, i.e. 8-bit intensities are assumed.

    Args:
        original: Reference image array, 2-D grayscale or 3-D with the
            channels on the last axis.
        super_resolved: Image to score; it is resized to the reference's
            height and width when the shapes differ.

    Returns:
        A dict with keys ``'PSNR'`` and ``'SSIM'``.
    """
    # Both metrics need identically shaped inputs; cv2.resize takes
    # (width, height).
    if original.shape != super_resolved.shape:
        super_resolved = cv2.resize(
            super_resolved, (original.shape[1], original.shape[0])
        )

    psnr_value = psnr(original, super_resolved, data_range=255)

    # `channel_axis` replaces the removed `multichannel` keyword. A 2-D
    # grayscale input has no channel axis, so None is passed for it.
    channel_axis = -1 if original.ndim == 3 else None
    ssim_value = ssim(
        original, super_resolved, channel_axis=channel_axis, data_range=255
    )
    return {
        'PSNR': psnr_value,
        'SSIM': ssim_value,
    }
# 使用示例
def demo_evaluation():
    """Read a reference / super-resolved image pair and print PSNR and SSIM."""
    # Same read order as before: reference first, then the result to score.
    reference, candidate = (
        cv2.imread(path) for path in ('original.jpg', 'super_resolved.jpg')
    )
    scores = evaluate_super_resolution(reference, candidate)
    print(f"PSNR: {scores['PSNR']:.2f} dB")
    print(f"SSIM: {scores['SSIM']:.4f}")

总结:

  1. 传统方法:使用OpenCV的插值算法,简单快速但效果一般
  2. 深度学习方法:SRCNN、ESPCN等,需要训练但效果更好
  3. 预训练模型:使用OpenCV DNN或Hugging Face的模型,最实用
  4. 实际应用:建议使用预训练的深度学习模型,平衡效果和效率

对于大多数应用场景,推荐使用预训练的深度超分模型,它们在效果和速度上都表现良好。

抱歉,评论功能暂时关闭!