Source code for heavyedge_dataset

"""Package to load edge profile data using PyTorch dataset."""

import numbers
from collections.abc import Sequence

import numpy as np
from heavyedge_landmarks import (
    landmarks_type1,
    landmarks_type2,
    landmarks_type3,
    pseudo_landmarks,
)
from torch.utils.data import Dataset

# Names exported by a star-import of this module: the three dataset classes
# defined below.
__all__ = [
    "ProfileDataset",
    "PseudoLandmarkDataset",
    "MathematicalLandmarkDataset",
]


class ProfileDataset(Dataset):
    """Edge profile dataset.

    Loads data as a tuple of two numpy arrays:

    1. Profile data, shape: (N, m, L).
    2. Length of each profile, shape: (N,).

    N is the number of loaded data, m is dimension of coordinates, and
    L is the maximum length of profiles.

    Parameters
    ----------
    file : heavyedge.ProfileData
        Open hdf5 file.
    m : {1, 2}
        Profile data dimension.
        1 means only y coordinates, and 2 means both x and y coordinates.
    transform : callable, optional
        Optional transformation to be applied on samples.

    Notes
    -----
    When indexed with a sequence or a numpy array of indices, the indices
    may be given in any order and may contain repeats; each distinct profile
    is read from the file once and the result is returned in the requested
    order.

    Examples
    --------
    >>> from heavyedge import get_sample_path, ProfileData
    >>> from heavyedge_dataset import ProfileDataset
    >>> with ProfileData(get_sample_path("Prep-Type2.h5")) as file:
    ...     profiles, lengths = ProfileDataset(file, m=1)[:]
    >>> profiles.shape
    (22, 1, 3200)
    >>> with ProfileData(get_sample_path("Prep-Type2.h5")) as file:
    ...     profiles, lengths = ProfileDataset(file, m=2)[:]
    >>> profiles.shape
    (22, 2, 3200)
    >>> lengths.shape
    (22,)
    >>> import matplotlib.pyplot as plt  # doctest: +SKIP
    ... plt.plot(*profiles.transpose(1, 2, 0))
    """

    def __init__(self, file, m=1, transform=None):
        self.file = file
        self.m = m
        self.transform = transform
        # x coordinates are read once and shared by every profile.
        self.x = file.x()

    def __len__(self):
        """Return the number of profiles in the file."""
        return len(self.file)

    def __getitem__(self, idx):
        """Load the profile(s) at *idx* together with their lengths.

        Raises
        ------
        ValueError
            If ``self.m`` is neither 1 nor 2.
        """
        if isinstance(idx, numbers.Integral):
            Y, L, _ = self.file[idx]
            # Add the coordinate axis in front of the single profile.
            Y = Y[np.newaxis, :]
        else:
            # Support multi-indexing
            if isinstance(idx, (Sequence, np.ndarray)):
                # h5py only accepts index lists in increasing order without
                # repeats, so read the distinct indices in sorted order and
                # expand back to the requested order (repeats included).
                unique_idxs, inverse = np.unique(
                    np.asarray(idx), return_inverse=True
                )
                Y, L, _ = self.file[unique_idxs]
                Y = Y[inverse]
                L = L[inverse]
            else:
                # E.g. slices, which can be passed to the file as they are.
                Y, L, _ = self.file[idx]
            # Add the coordinate axis after the sample axis.
            Y = Y[:, np.newaxis, :]

        if self.m == 1:
            pass
        elif self.m == 2:
            # Repeat x for every loaded profile and stack it before y.
            x = np.tile(self.x, Y.shape[:-1] + (1,))
            Y = np.concatenate([x, Y], axis=-2)
        else:
            raise ValueError(f"Unsupported dimension: {self.m} (Must be 1 or 2).")

        ret = (Y, L)
        if self.transform is not None:
            ret = self.transform(ret)
        return ret

    def __getitems__(self, idxs):
        # PyTorch API
        return self.__getitem__(idxs)
class PseudoLandmarkDataset(Dataset):
    """Dataset for pseudo-landmarks of edge profiles.

    Parameters
    ----------
    file : heavyedge.ProfileData
        Open hdf5 file.
    m : {1, 2}
        Dimension of landmark coordinates.
    k : int
        Number of landmarks to sample.
    transform : callable, optional
        Optional transformation to be applied on samples.

    Examples
    --------
    >>> from heavyedge import ProfileData, get_sample_path
    >>> from heavyedge_dataset import PseudoLandmarkDataset
    >>> with ProfileData(get_sample_path("Prep-Type1.h5")) as file:
    ...     dataset = PseudoLandmarkDataset(file, 1, 10)
    ...     data = dataset[:]
    >>> data.shape
    (18, 1, 10)
    >>> import matplotlib.pyplot as plt  # doctest: +SKIP
    ... plt.plot(*data.transpose(1, 2, 0))
    """

    def __init__(self, file, m, k, transform=None):
        # Profiles are always loaded as y coordinates only; x comes from
        # ``self.profiles.x``.
        self.profiles = ProfileDataset(file, m=1)
        self.m = m
        self.k = k
        self.transform = transform

    def __len__(self):
        """Return the number of profiles in the underlying dataset."""
        return len(self.profiles)

    def __getitem__(self, idx):
        """Sample pseudo-landmarks from the profile(s) at *idx*.

        Raises
        ------
        ValueError
            If ``self.m`` is neither 1 nor 2.
        """
        # Reject an unsupported dimension before any data is read, instead of
        # silently falling back to two-dimensional output.
        if self.m not in (1, 2):
            raise ValueError(f"Unsupported dimension: {self.m} (Must be 1 or 2).")

        is_single = isinstance(idx, numbers.Integral)
        Ys, Ls = self.profiles[idx]
        if is_single:
            # Add a sample axis so one profile is handled like a batch of one.
            Ys, Ls = Ys[np.newaxis, ...], Ls[np.newaxis, ...]
        # Drop the coordinate axis added by ProfileDataset.
        Ys = Ys.squeeze(axis=1)

        ret = pseudo_landmarks(self.profiles.x, Ys, Ls, self.k)
        if self.m == 1:
            # Keep only index 1 of the coordinate axis (presumably the y
            # coordinates, going by the documented output shape).
            ret = ret[:, 1:2, :]
        if is_single:
            # Remove the sample axis added above.
            ret = ret[0]

        if self.transform is not None:
            ret = self.transform(ret)
        return ret

    def __getitems__(self, idxs):
        # PyTorch API
        return self.__getitem__(idxs)
class MathematicalLandmarkDataset(Dataset):
    """Dataset for mathematical landmarks of edge profiles.

    Loads data as a tuple of two numpy arrays:

    1. Landmark coordinates, shape: (N, m, k).
    2. Average plateau height, shape: (N,).

    N is the number of loaded data, m is dimension of coordinates, and
    k is the number of mathematical landmarks detected; k=5 for type 3
    profiles and k=4 for type 1 and type 2 profiles.

    Parameters
    ----------
    file : heavyedge.ProfileData
        Open hdf5 file.
    m : {1, 2}
        Dimension of landmark coordinates.
    sigma : scalar
        Standard deviation of Gaussian kernel for landmark detection.
    ptype : {1, 2, 3}, default=3
        Assumed type of edge profiles.
    transform : callable, optional
        Optional transformation to be applied on samples.

    Notes
    -----
    Unlike HeavyEdge-Landmarks package, landmark points returned by this
    dataset are sorted by ascending X coordinates. Additionally, the point
    at X=0 is included as the first landmark.

    Examples
    --------
    >>> from heavyedge import ProfileData, get_sample_path
    >>> from heavyedge_dataset import MathematicalLandmarkDataset
    >>> with ProfileData(get_sample_path("Prep-Type3.h5")) as file:
    ...     dataset = MathematicalLandmarkDataset(file, 1, 32)
    ...     landmarks, height = dataset[:]
    >>> landmarks.shape
    (35, 1, 5)
    >>> height.shape
    (35,)
    >>> with ProfileData(get_sample_path("Prep-Type3.h5")) as file:
    ...     dataset = MathematicalLandmarkDataset(file, 2, 32)
    ...     landmarks, height = dataset[:]
    >>> landmarks.shape
    (35, 2, 5)
    >>> import matplotlib.pyplot as plt  # doctest: +SKIP
    ... plt.plot(*landmarks.transpose(1, 2, 0))
    """

    def __init__(self, file, m, sigma, ptype=3, transform=None):
        # Profiles are always loaded as y coordinates only; x comes from
        # ``self.profiles.x``.
        self.profiles = ProfileDataset(file, m=1)
        self.m = m
        self.sigma = sigma
        self.ptype = ptype
        self.transform = transform

    def __len__(self):
        """Return the number of profiles in the underlying dataset."""
        return len(self.profiles)

    def __getitem__(self, idx):
        """Detect landmarks and plateau height for the profile(s) at *idx*.

        Raises
        ------
        ValueError
            If ``self.m`` is neither 1 nor 2, or if ``self.ptype`` is not
            1, 2, or 3.
        """
        # Validate the configuration before any data is read.
        if self.m not in (1, 2):
            raise ValueError(f"Unsupported dimension: {self.m} (Must be 1 or 2).")
        if self.ptype == 1:
            detect = landmarks_type1
        elif self.ptype == 2:
            detect = landmarks_type2
        elif self.ptype == 3:
            detect = landmarks_type3
        else:
            raise ValueError(
                f"Unsupported profile type: {self.ptype} (Must be 1, 2, or 3)."
            )

        x = self.profiles.x
        is_single = isinstance(idx, numbers.Integral)
        Ys, Ls = self.profiles[idx]
        if is_single:
            # Add a sample axis so one profile is handled like a batch of one.
            Ys, Ls = Ys[np.newaxis, ...], Ls[np.newaxis, ...]
        # Drop the coordinate axis added by ProfileDataset.
        Ys = Ys.squeeze(axis=1)

        lm = detect(x, Ys, Ls, self.sigma)
        # Reverse the landmark order along the last axis.
        X = np.flip(lm, axis=-1)
        # Prepend point at x=0
        x_zeros = np.stack([np.full((len(Ys),), x[0]), Ys[:, 0]], axis=1)
        X = np.concatenate([x_zeros[..., np.newaxis], X], axis=-1)

        # Find average plateau height: mean of each profile up to the x
        # position of the first landmark after the prepended point.
        knee_idxs = np.searchsorted(x, X[:, 0, 1])
        H = np.array(
            [np.mean(Y[:knee_idx]) for Y, knee_idx in zip(Ys, knee_idxs)]
        )

        if self.m == 1:
            # Keep only index 1 of the coordinate axis (the row built from
            # the profile values above).
            X = X[:, 1:2, :]
        if is_single:
            # Remove the sample axis added above.
            ret = (X[0], H[0])
        else:
            ret = (X, H)

        if self.transform is not None:
            ret = self.transform(ret)
        return ret

    def __getitems__(self, idxs):
        # PyTorch API
        return self.__getitem__(idxs)