"""Source code for ``sldl.image.denoising``: pre-trained image denoising models."""

import torch
from torch import nn
from .swinir import SwinIR, swin_ir_inference
from PIL import Image

from sldl._utils import get_checkpoint_path


class ImageDenoising(nn.Module):
    r"""Image Denoising

    Takes an image and removes the noise from it. Currently supports SwinIR only.

    :param model_name: Name of the pre-trained model. Now it can only be `SwinIR`.
    :type model_name: str
    :param noise: Noise level that the model was trained on. Can be one of `15`, `25`, `50`.
    :type noise: int
    :param precision: Can be either `full` (uses fp32) or `half` (uses fp16). Default: `full`.
    :type precision: str
    :raises ValueError: If ``model_name``, ``noise`` or ``precision`` is not one of the
        supported values listed above.

    Example:

    .. code-block:: python

        from PIL import Image
        from sldl.image import ImageDenoising

        denoiser = ImageDenoising('SwinIR')
        img = Image.open('test.png')
        denoised = denoiser(img)
    """

    # Accepted constructor values (see the parameter descriptions above). Validating them
    # up front turns a confusing late failure (missing ``self.model``, or a download of a
    # checkpoint URL that does not exist) into an immediate, descriptive ValueError.
    _SUPPORTED_MODELS = ("SwinIR",)
    _SUPPORTED_NOISE_LEVELS = (15, 25, 50)
    _SUPPORTED_PRECISIONS = ("full", "half")

    def __init__(
        self, model_name: str = "SwinIR", noise: int = 15, precision: str = "full"
    ):
        super().__init__()
        if model_name not in self._SUPPORTED_MODELS:
            raise ValueError(
                f"Unsupported model_name {model_name!r}; "
                f"expected one of {self._SUPPORTED_MODELS}."
            )
        if noise not in self._SUPPORTED_NOISE_LEVELS:
            raise ValueError(
                f"Unsupported noise level {noise!r}; "
                f"expected one of {self._SUPPORTED_NOISE_LEVELS}."
            )
        if precision not in self._SUPPORTED_PRECISIONS:
            raise ValueError(
                f"Unsupported precision {precision!r}; "
                f"expected one of {self._SUPPORTED_PRECISIONS}."
            )

        self.model_name = model_name
        self.precision = precision

        # Architecture hyper-parameters; they must match the checkpoint loaded below,
        # since the state dict is loaded with strict=True.
        self.model = SwinIR(
            upscale=1,
            in_chans=3,
            img_size=128,
            window_size=8,
            img_range=1.0,
            depths=[6, 6, 6, 6, 6, 6],
            embed_dim=180,
            num_heads=[6, 6, 6, 6, 6, 6],
            mlp_ratio=2,
            upsampler="",
            resi_connection="1conv",
        )
        path = get_checkpoint_path(
            f"https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/005_colorDN_DFWB_s128w8_SwinIR-M_noise{noise}.pth"
        )
        # map_location="cpu" lets the checkpoint load on machines that lack the device it
        # was saved from; move the module afterwards with ``.to(device)``.
        # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints
        # (here the official SwinIR release). Consider weights_only=True once the
        # minimum supported PyTorch version accepts it.
        pretrained_model = torch.load(path, map_location="cpu")
        # Some checkpoints wrap the state dict under a "params" key; others are the
        # state dict itself.
        state_dict = (
            pretrained_model["params"] if "params" in pretrained_model else pretrained_model
        )
        self.model.load_state_dict(state_dict, strict=True)

        if self.precision == "half":
            self.model = self.model.half()

        # nn.Module instances start in training mode; this wrapper is inference-only, so
        # disable any train-time-only layer behavior (dropout, stochastic depth, ...).
        self.eval()

    @property
    def device(self) -> torch.device:
        """Device on which the model's parameters currently live."""
        return next(self.parameters()).device

    def __call__(self, img: "Image.Image") -> "Image.Image":
        """Applies the denoiser.

        :param img: An input image.
        :type img: PIL.Image.Image
        :return: Denoised image.
        :rtype: PIL.Image.Image
        """
        return swin_ir_inference(
            self.model, img, device=self.device, precision=self.precision
        )