Source code for mammopy.pre_process

import numpy as np
import requests
from pathlib import Path
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image

class pre_process():
    """
    Helpers for loading a pre-trained model and converting images to tensors.

    Both helpers take no instance state and are declared as static methods,
    so they can be called either on the class (``pre_process.load_model(...)``)
    or on an instance.
    """

    @staticmethod
    def load_model(model_path):
        """
        Loads a pre-trained model.

        Args:
        - model_path: str, selector for the model to load. Only 'base' is
          supported: the default pre-trained model, read from "weights.pth"
          in the current working directory and downloaded first when that
          file does not exist.

        Returns:
        - model: torch.nn.DataParallel, the loaded model (mapped to CPU)
          wrapped in DataParallel.

        Raises:
        - TypeError: if the input model_path is not 'base'.
        - requests.RequestException: if the weights have to be downloaded
          and the request fails or returns an HTTP error status.
        """
        if model_path != 'base':
            raise TypeError("Model not implemented")

        # Define the path to the default pre-trained model weights
        model_weights_path = "weights.pth"
        path = Path(model_weights_path)

        # If the model weights file does not exist, download it
        if not path.is_file():
            # NOTE(review): the download URL is empty in this source; it has
            # to be filled in before the download branch can succeed.
            url = ""
            response = requests.get(url)
            # Fail loudly on an HTTP error instead of caching an error page
            # as "weights.pth" (which would break every later load).
            response.raise_for_status()
            # write_bytes opens, writes and closes the file in one step.
            path.write_bytes(response.content)

        # Load the pre-trained model and set the device to CPU.
        # NOTE(security): torch.load unpickles the file; only load weights
        # that come from a trusted source.
        model = torch.load(model_weights_path, map_location=torch.device("cpu"))

        # The checkpoint is unwrapped through its ``module`` attribute
        # (presumably it was saved from a DataParallel wrapper - confirm);
        # fall back to the loaded object itself when it has no such attribute.
        # Then wrap it again to be compatible with multiple GPUs, if available.
        model = nn.DataParallel(getattr(model, "module", model))
        return model

    @staticmethod
    def image_tensor(img):
        """
        Converts a PIL Image or NumPy array to a PyTorch tensor.

        The image is resized to 256x256, converted to a tensor and given a
        leading batch dimension of size 1.

        Args:
        - img: PIL.Image or np.ndarray, the image to be converted to a
          PyTorch tensor. Subclasses are accepted, including the image
          objects returned by ``PIL.Image.open``. NumPy arrays are converted
          to an RGB PIL image first; PIL images keep their own mode.

        Returns:
        - image: torch.Tensor, the converted PyTorch tensor with a leading
          batch dimension.

        Raises:
        - TypeError: if the input image is not a PIL.Image or np.ndarray.
        """
        if isinstance(img, np.ndarray):
            # Convert NumPy array to RGB PIL image
            img = Image.fromarray(img).convert("RGB")
        elif not isinstance(img, Image.Image):
            raise TypeError("Input must be np.ndarray or PIL.Image")

        # Define a PyTorch tensor transformer pipeline
        torch_tensor = transforms.Compose(
            [transforms.Resize((256, 256)), transforms.ToTensor()]
        )

        # Convert the PIL image to a tensor and add the batch dimension
        return torch_tensor(img).unsqueeze(0)