diff --git a/reconstruction/jmlr/augs.py b/reconstruction/jmlr/augs.py new file mode 100644 index 0000000..2977afb --- /dev/null +++ b/reconstruction/jmlr/augs.py @@ -0,0 +1,271 @@ +import numpy as np +import cv2 +import os +import os.path as osp +import albumentations as A +from albumentations.core.transforms_interface import ImageOnlyTransform +from albumentations.pytorch import ToTensorV2 +from insightface.app.mask_renderer import MaskRenderer, MaskAugmentation # used by the 'prod' mode and the __main__ demo below; import path assumed from the insightface package + +# Keep a random square patch of the (square) input and fill the remaining border with fill_value. +class RectangleBorderAugmentation(ImageOnlyTransform): + + def __init__( + self, + fill_value = 0, + fg_limit = (0.7, 0.9), + always_apply=False, + p=1.0, + ): + super(RectangleBorderAugmentation, self).__init__(always_apply, p) + #assert limit>0.0 and limit<1.0 + assert isinstance(fg_limit, tuple) + assert fg_limit[1]>fg_limit[0] + self.fill_value = fill_value + self.fg_limit = fg_limit + #self.output_size = output_size + + + def apply(self, image, fg, top, left, **params): + assert image.shape[0]==image.shape[1] + oimage = np.ones_like(image) * self.fill_value + f = int(fg*image.shape[0]) + t = int(top*image.shape[0]) + l = int(left*image.shape[1]) + oimage[t:t+f,l:l+f,:] = image[t:t+f,l:l+f,:] + return oimage + + def get_params(self): + fg = np.random.uniform(self.fg_limit[0], self.fg_limit[1]) + top = np.random.uniform(0.0, 1.0-fg) + left = np.random.uniform(0.0, 1.0-fg) + return {'fg': fg, 'top': top, 'left': left} + + def get_transform_init_args_names(self): + return ('fill_value','fg_limit') + +# Black out two filled circles at the aligned eye positions to mimic sunglasses. +class SunGlassAugmentation(ImageOnlyTransform): + + def __init__( + self, + fill_value = 0, + loc = [ (38, 52), (73, 52) ], + rad_limit = (10, 20), + always_apply=False, + p=1.0, + ): + super(SunGlassAugmentation, self).__init__(always_apply, p) + #assert limit>0.0 and limit<1.0 + assert isinstance(rad_limit, tuple) + self.fill_value = fill_value + self.loc = loc + self.rad_limit = rad_limit + + + def apply(self, image, rad, **params): + for i in range(2): + cv2.circle(image, self.loc[i], rad, self.fill_value, -1) + return image + + def get_params(self): + rad = np.random.randint(self.rad_limit[0], self.rad_limit[1]) + return {'rad':rad} + + def get_transform_init_args_names(self): + return ('fill_value', 'loc', 'rad_limit') + +# Paste a random-noise rectangle over the forehead (top) region of the image. +class ForeHeadAugmentation(ImageOnlyTransform): + + def __init__( + self, + height_min = 0.2, + height_max = 0.4, + width_min = 0.5, + always_apply=False, + p=1.0, + ): + super(ForeHeadAugmentation, self).__init__(always_apply, p) + assert height_max > height_min + #assert limit>0.0 and limit<1.0 + self.height_min = height_min + self.height_max = height_max + self.width_min = width_min + + + def apply(self, image, height, width, left, **params): + mask_value = np.random.randint(0, 255, size=(int(image.shape[0]*height), int(image.shape[1]*width), 3), dtype=image.dtype) + l = int(image.shape[1]*left) + image[:mask_value.shape[0], l:l+mask_value.shape[1], :] = mask_value + return image + + def get_params(self): + height = np.random.uniform(self.height_min, self.height_max) + width = np.random.uniform(self.width_min, 1.0) + left = np.random.uniform(0.0, 1.0 - width) + return {'height': height, 'width': width, 'left': left} + + def get_transform_init_args_names(self): + return ('height_min', 'height_max','width_min') + + +def get_aug_transform(cfg): + aug_modes = cfg.aug_modes + input_size = cfg.input_size + task = cfg.task + transform_list = [] + is_test = False + if 'test-aug' in aug_modes: + #transform_list.append( + # A.RandomBrightnessContrast(brightness_limit=0.125, contrast_limit=0.05, p=0.2) + # ) + transform_list.append( + A.ShiftScaleRotate(shift_limit=0.02, scale_limit=0.05, rotate_limit=5,
interpolation=cv2.INTER_LINEAR, + border_mode=cv2.BORDER_CONSTANT, value=0, mask_value=0, p=1.0, always_apply=True) + ) + is_test = True + + if '1' in aug_modes: + transform_list.append( + A.RandomBrightnessContrast(brightness_limit=0.125, contrast_limit=0.05, p=0.2) + ) + if '1A' in aug_modes: + transform_list.append( + A.RandomBrightnessContrast(brightness_limit=0.125, contrast_limit=0.05, p=0.2) + ) + transform_list.append( + A.ShiftScaleRotate(shift_limit=0.02, scale_limit=0.03, rotate_limit=6, interpolation=cv2.INTER_LINEAR, + border_mode=cv2.BORDER_CONSTANT, value=0, mask_value=0, p=0.3) + ) + if '2' in aug_modes: + transform_list.append( + A.RandomBrightnessContrast(brightness_limit=0.125, contrast_limit=0.05, p=0.2) + ) + transform_list.append( + A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.1, rotate_limit=15, interpolation=cv2.INTER_LINEAR, + border_mode=cv2.BORDER_CONSTANT, value=0, mask_value=0, p=0.4) + ) + if '3' in aug_modes: + transform_list.append( + A.RandomBrightnessContrast(brightness_limit=0.125, contrast_limit=0.05, p=0.6) + ) + transform_list.append( + A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=30, interpolation=cv2.INTER_LINEAR, + border_mode=cv2.BORDER_CONSTANT, value=0, mask_value=0, p=0.6) + ) + if 'nist1' in aug_modes: + transform_list.append( + A.RandomBrightnessContrast(brightness_limit=0.125, contrast_limit=0.05, p=0.2) + ) + transform_list.append( + A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.06, rotate_limit=6, interpolation=cv2.INTER_LINEAR, + border_mode=cv2.BORDER_CONSTANT, value=0, mask_value=0, p=0.4) + ) + if 'nist2' in aug_modes: + transform_list.append( + #A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.02, p=0.3) + A.RandomBrightnessContrast(brightness_limit=0.125, contrast_limit=0.05, p=0.2) + ) + transform_list.append( + A.ShiftScaleRotate(shift_limit=0.06, scale_limit=0.06, rotate_limit=6, interpolation=cv2.INTER_LINEAR, + border_mode=cv2.BORDER_CONSTANT, value=0, mask_value=0, p=0.4) + ) + transform_list.append( + A.OneOf([ + RectangleBorderAugmentation(p=0.5), + ForeHeadAugmentation(p=0.5), + #SunGlassAugmentation(p=0.2), + ], p=0.06) + ) + transform_list.append( + A.ToGray(p=0.05) + ) + transform_list.append( + A.geometric.resize.RandomScale(scale_limit=(0.7, 0.9), interpolation=cv2.INTER_LINEAR, p=0.05) + ) + transform_list.append( + A.ISONoise(p=0.06) + ) + transform_list.append( + A.MedianBlur(blur_limit=(1,7), p=0.05) + ) + transform_list.append( + A.MotionBlur(blur_limit=(5,12), p=0.05) + ) + transform_list.append( + A.ImageCompression(quality_lower=50, quality_upper=80, p=0.05) + ) + if 'prod' in aug_modes: + transform_list.append( + #A.RandomBrightnessContrast(brightness_limit=0.125, contrast_limit=0.125, p=0.2) + A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.02, p=0.3) + ) + transform_list.append( + A.ShiftScaleRotate(shift_limit=0.06, scale_limit=0.1, rotate_limit=10, interpolation=cv2.INTER_LINEAR, + border_mode=cv2.BORDER_CONSTANT, value=0, mask_value=0, p=0.4) + ) + transform_list.append( + A.OneOf([ + RectangleBorderAugmentation(p=0.5), + ForeHeadAugmentation(p=0.5), + MaskAugmentation(mask_names=['mask_white', 'mask_blue', 'mask_black', 'mask_green'], mask_probs=[0.4, 0.4, 0.1, 0.1], h_low=0.33, h_high=0.4, p=0.2), + SunGlassAugmentation(p=0.2), + ], p=0.2) + ) + transform_list.append( + A.ToGray(p=0.05) + ) + transform_list.append( + A.geometric.resize.RandomScale(scale_limit=(0.6, 0.9), interpolation=cv2.INTER_LINEAR, p=0.2) + ) + transform_list.append( + 
A.ISONoise(p=0.1) + ) + transform_list.append( + A.MedianBlur(blur_limit=(1,7), p=0.1) + ) + transform_list.append( + A.MotionBlur(blur_limit=(5,12), p=0.1) + ) + transform_list.append( + A.ImageCompression(quality_lower=30, quality_upper=80, p=0.1) + ) + #if input_size!=112: # TODO!! + # transform_list.append( + # A.geometric.resize.Resize(input_size, input_size, interpolation=cv2.INTER_LINEAR, always_apply=True) + # ) + transform_list += \ + [ + #A.HorizontalFlip(p=0.5), + A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ToTensorV2(), + ] + #here, the input for A transform is rgb cv2 img + if is_test: + transform = A.ReplayCompose( + transform_list , + keypoint_params=A.KeypointParams(format='xy',remove_invisible=False) + ) + else: + transform = A.Compose( + transform_list, + keypoint_params=A.KeypointParams(format='xy',remove_invisible=False) + ) + return transform + + +if __name__ == "__main__": + tool = MaskRenderer() + tool.prepare(ctx_id=0, det_size=(128,128)) + image = cv2.imread("./test1.png")[:,:,::-1] + mask_image = "mask_blue" + #params = tool.build_params(image) + label = np.load('assets/mask_label.npy') + params = tool.decode_params(label) + #print(params[0][:20]) + mask_out = tool.render_mask(image, mask_image, params, input_is_rgb=True, auto_blend=False)[:,:,::-1] + #print(uv_out.dtype, uv_out.shape) + cv2.imwrite('output_mask.jpg', mask_out) + transform = A.Compose([ + MaskAugmentation(mask_names=['mask_white', 'mask_blue', 'mask_black', 'mask_green'], mask_probs=[0.4, 0.4, 0.1, 0.1], h_low=0.33, h_high=0.4, p=1.0), + #MaskAugmentation(p=1.0), + ]) + mask_out = transform(image=image, hlabel=label)["image"][:,:,::-1] + cv2.imwrite('output_mask2.jpg', mask_out) diff --git a/reconstruction/jmlr/backbones/__init__.py b/reconstruction/jmlr/backbones/__init__.py new file mode 100644 index 0000000..5f46a9d --- /dev/null +++ b/reconstruction/jmlr/backbones/__init__.py @@ -0,0 +1 @@ +from .network import get_network diff --git a/reconstruction/jmlr/backbones/iresnet.py b/reconstruction/jmlr/backbones/iresnet.py new file mode 100644 index 0000000..2103c1e --- /dev/null +++ b/reconstruction/jmlr/backbones/iresnet.py @@ -0,0 +1,326 @@ +import torch +from torch import nn +import torch.nn.functional as F +import logging + +__all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200'] + + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, + out_planes, + kernel_size=1, + stride=stride, + bias=False) + + +class IBasicBlock(nn.Module): + expansion = 1 + def __init__(self, inplanes, planes, stride=1, downsample=None, + groups=1, base_width=64, dilation=1, eps=1e-5, dropblock=0.0): + super(IBasicBlock, self).__init__() + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + self.bn1 = nn.BatchNorm2d(inplanes, eps=eps) + self.conv1 = conv3x3(inplanes, planes) + self.bn2 = nn.BatchNorm2d(planes, eps=eps) + self.prelu = nn.PReLU(planes) + self.conv2 = conv3x3(planes, planes, stride) + self.bn3 = nn.BatchNorm2d(planes, eps=eps) + self.downsample = downsample + self.stride = stride + self.dbs = None + if dropblock>0.0: + 
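+ # DropBlock (Ghiasi et al., 2018) zeroes contiguous feature-map patches rather than independent units; timm is imported lazily here so it is only required when dropblock>0.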
import timm + from timm.layers import DropBlock2d + self.dbs = [DropBlock2d(dropblock, 7), DropBlock2d(dropblock, 7), DropBlock2d(dropblock, 7)] + + def forward(self, x): + identity = x + out = self.bn1(x) + out = self.conv1(out) + if self.dbs is not None: + out = self.dbs[0](out) + out = self.bn2(out) + out = self.prelu(out) + out = self.conv2(out) + if self.dbs is not None: + out = self.dbs[1](out) + out = self.bn3(out) + if self.downsample is not None: + identity = self.downsample(x) + out += identity + if self.dbs is not None: + out = self.dbs[2](out) + return out + + +class IResNet(nn.Module): + def __init__(self, + block, layers, dropout=0.0, num_features=512, input_size=112, zero_init_residual=False, + stem_type='', dropblock = 0.0, kaiming_init=False, + groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=0): + super(IResNet, self).__init__() + self.input_size = input_size + assert self.input_size%16==0 + fc_scale = self.input_size // 16 + self.fc_scale = fc_scale*fc_scale + self.fp16 = fp16 + self.inplanes = 64 + self.dilation = 1 + self.norm_layer = nn.BatchNorm2d + self.act_layer = nn.PReLU + self.eps = 1e-5 + if kaiming_init: + self.eps = 2e-5 + self.stem_type = stem_type + self.dropblock = dropblock + if replace_stride_with_dilation is None: + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + if stem_type!='D': + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) + else: + stem_width = self.inplanes // 2 + stem_chs = [stem_width, stem_width] + self.conv1 = nn.Sequential(*[ + nn.Conv2d(3, stem_chs[0], 3, stride=1, padding=1, bias=False), + self.norm_layer(stem_chs[0], eps=self.eps), + self.act_layer(stem_chs[0]), + nn.Conv2d(stem_chs[0], stem_chs[1], 3, stride=1, padding=1, bias=False), + self.norm_layer(stem_chs[1], eps=self.eps), + self.act_layer(stem_chs[1]), + nn.Conv2d(stem_chs[1], self.inplanes, 3, stride=1, padding=1, bias=False)]) + logging.info("iresnet, input_size: %d, fc_scale: %d, dropout: %.2f, stem_type: %s, fp16: %d"%(self.input_size, self.fc_scale, dropout, stem_type, self.fp16)) + logging.info("iresnet, eps: %.6f, dropblock: %.3f, kaiming_init: %d"%(self.eps, self.dropblock, kaiming_init)) + #self.conv1.requires_grad = False + self.bn1 = nn.BatchNorm2d(self.inplanes, eps=self.eps) + #self.bn1.requires_grad = False + self.prelu = nn.PReLU(self.inplanes) + #self.prelu.requires_grad = False + self.layer1 = self._make_layer(block, 64, layers[0], stride=2) + #self.layer1.requires_grad = False + self.layer2 = self._make_layer(block, + 128, + layers[1], + stride=2, + dilate=replace_stride_with_dilation[0]) + #self.layer2.requires_grad = False + self.layer3 = self._make_layer(block, + 256, + layers[2], + stride=2, + dilate=replace_stride_with_dilation[1], + dropblock=self.dropblock) + #self.layer3.requires_grad = False + self.layer4 = self._make_layer(block, + 512, + layers[3], + stride=2, + dilate=replace_stride_with_dilation[2], + dropblock=self.dropblock) + #self.layer4.requires_grad = False + self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=self.eps) + #self.bn2.requires_grad = False + if dropout>0.0: + self.dropout = nn.Dropout(p=dropout, inplace=True) + else: + self.dropout = None + self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features) + 
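# Embedding head: flatten -> optional dropout -> fc -> BatchNorm1d ('features'); the BN scale is frozen at 1.0 below so only its bias adapts, as is common for ArcFace-style embeddings. +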
#self.fc.requires_grad = False + self.features = nn.BatchNorm1d(num_features, eps=self.eps) + nn.init.constant_(self.features.weight, 1.0) + self.features.weight.requires_grad = False + + #for m in self.modules(): + # if isinstance(m, nn.Conv2d): + # nn.init.xavier_uniform_(m.weight.data) + # if m.bias is not None: + # m.bias.data.zero_() + # elif isinstance(m, nn.BatchNorm2d): + # m.weight.data.fill_(1) + # m.bias.data.zero_() + # elif isinstance(m, nn.BatchNorm1d): + # m.weight.data.fill_(1) + # m.bias.data.zero_() + # elif isinstance(m, nn.Linear): + # nn.init.xavier_uniform_(m.weight.data) + # if m.bias is not None: + # m.bias.data.zero_() + #nn.init.constant_(self.features.weight, 1.0) + #self.features.weight.requires_grad = False + + #for m in self.modules(): + # if kaiming_init: + # if isinstance(m, (nn.Conv2d, nn.Linear)): + # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + # if m.bias is not None: + # nn.init.constant_(m.bias, 0) + # else: + # if isinstance(m, (nn.Conv2d, nn.Linear)): + # nn.init.normal_(m.weight, 0, 0.1) + # if m.bias is not None: + # nn.init.constant_(m.bias, 0) + # if isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + # nn.init.constant_(m.weight, 1) + # nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False, dropblock=0.0): + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + if self.stem_type!='D': + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + nn.BatchNorm2d(planes * block.expansion, eps=self.eps), + ) + else: + #avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d + avg_stride = stride + pool = nn.AvgPool2d(2, avg_stride, ceil_mode=True, count_include_pad=False) + downsample = nn.Sequential(*[ + pool, + conv1x1(self.inplanes, planes * block.expansion, stride=1), + nn.BatchNorm2d(planes * block.expansion, eps=self.eps), + ]) + layers = [] + layers.append( + block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, self.eps, dropblock)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block(self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation, + eps=self.eps, + dropblock=dropblock)) + + return nn.Sequential(*layers) + + def forward(self, x): + #if self.input_size!=112: + # x = F.interpolate(x, [self.input_size, self.input_size], mode='bilinear', align_corners=False) + is_fp16 = self.fp16>0 + with torch.cuda.amp.autocast(is_fp16): + x = self.conv1(x) + x = self.bn1(x) + x = self.prelu(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + if self.fp16<3: + x = self.layer4(x) + x = self.bn2(x) + x = torch.flatten(x, 1) + if self.dropout is not None: + x = self.dropout(x) + if is_fp16: + x = x.float() + if self.fp16>=3: + x = self.layer4(x) + x = self.bn2(x) + x = torch.flatten(x, 1) + if self.dropout is not None: + x = self.dropout(x) + x = self.fc(x) + x = self.features(x) + return x + + +def _iresnet(arch, block, layers, pretrained, progress, **kwargs): + model = IResNet(block, layers, **kwargs) + if pretrained: + raise ValueError() + return model + + +def iresnet18(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained, + progress, **kwargs) + + +def iresnet34(pretrained=False, progress=True, 
**kwargs): + return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained, + progress, **kwargs) + + +def iresnet50(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained, + progress, **kwargs) + + +def iresnet100(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained, + progress, **kwargs) + +def iresnet120(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet120', IBasicBlock, [3, 16, 37, 3], pretrained, + progress, **kwargs) + +def iresnet160(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet160', IBasicBlock, [3, 16, 56, 3], pretrained, + progress, **kwargs) + +def iresnet180(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet180', IBasicBlock, [3, 20, 63, 3], pretrained, + progress, **kwargs) + +def iresnet200(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained, + progress, **kwargs) + +def iresnet247(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet247', IBasicBlock, [3, 36, 80, 4], pretrained, + progress, **kwargs) + +def iresnet269(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet269', IBasicBlock, [4, 46, 80, 4], pretrained, + progress, **kwargs) + +def iresnet300(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet300', IBasicBlock, [4, 46, 95, 4], pretrained, + progress, **kwargs) + + +def get_model(name, **kwargs): + if name == "r18": + return iresnet18(False, **kwargs) + elif name == "r34": + return iresnet34(False, **kwargs) + elif name == "r50": + return iresnet50(False, **kwargs) + elif name == "r100": + return iresnet100(False, **kwargs) + elif name == "r200": + return iresnet200(False, **kwargs) + elif name == "r2060": + from .iresnet2060 import iresnet2060 + return iresnet2060(False, **kwargs) + else: + raise ValueError() + diff --git a/reconstruction/jmlr/backbones/network.py b/reconstruction/jmlr/backbones/network.py new file mode 100644 index 0000000..9d759a8 --- /dev/null +++ b/reconstruction/jmlr/backbones/network.py @@ -0,0 +1,238 @@ +import os +import time +import timm +import glob +import numpy as np +import os.path as osp + +import torch +import torch.distributed as dist +from torch import nn +import torch.nn.functional as F +from .iresnet import get_model as arcface_get_model + + +def kaiming_leaky_init(m): + classname = m.__class__.__name__ + if classname.find('Linear') != -1: + torch.nn.init.kaiming_normal_(m.weight, a=0.2, mode='fan_in', nonlinearity='leaky_relu') + +class CustomMappingNetwork(nn.Module): + def __init__(self, z_dim, map_hidden_dim, map_output_dim): + super().__init__() + + + + self.network = nn.Sequential(nn.Linear(z_dim, map_hidden_dim), + nn.LeakyReLU(0.2, inplace=True), + + nn.Linear(map_hidden_dim, map_hidden_dim), + nn.LeakyReLU(0.2, inplace=True), + + nn.Linear(map_hidden_dim, map_hidden_dim), + nn.LeakyReLU(0.2, inplace=True), + + nn.Linear(map_hidden_dim, map_output_dim)) + + self.network.apply(kaiming_leaky_init) + with torch.no_grad(): + self.network[-1].weight *= 0.25 + + def forward(self, z): + frequencies_offsets = self.network(z) + frequencies = frequencies_offsets[..., :frequencies_offsets.shape[-1]//2] + phase_shifts = frequencies_offsets[..., frequencies_offsets.shape[-1]//2:] + + return frequencies, phase_shifts + +class FiLMLayer(nn.Module): + def __init__(self, input_dim, hidden_dim): + super().__init__() + self.layer = 
nn.Linear(input_dim, hidden_dim) + + def forward(self, x, freq, phase_shift): + x = self.layer(x) + return torch.sin(freq * x + phase_shift) + +class InstanceNorm(nn.Module): + def __init__(self, epsilon=1e-8): + """ + @notice: avoid in-place ops. + https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3 + """ + super(InstanceNorm, self).__init__() + self.epsilon = epsilon + + def forward(self, x): + x = x - torch.mean(x, (2, 3), True) + tmp = torch.mul(x, x) # or x ** 2 + tmp = torch.rsqrt(torch.mean(tmp, (2, 3), True) + self.epsilon) + return x * tmp + +class ApplyStyle(nn.Module): + def __init__(self, latent_size, channels): + super(ApplyStyle, self).__init__() + self.linear = nn.Linear(latent_size, channels * 2) + + def forward(self, x, latent): + style = self.linear(latent).unsqueeze(2).unsqueeze(3) #B, 2*c, 1, 1 + gamma, beta = style.chunk(2, 1) + x = gamma * x + beta + return x + +class ResnetBlock_Adain(nn.Module): + def __init__(self, dim, latent_size, padding_type='reflect', activation=nn.ReLU(True)): + super(ResnetBlock_Adain, self).__init__() + + p = 0 + conv1 = [] + if padding_type == 'reflect': + conv1 += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv1 += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + conv1 += [nn.Conv2d(dim, dim, kernel_size=3, padding = p), InstanceNorm()] + self.conv1 = nn.Sequential(*conv1) + self.style1 = ApplyStyle(latent_size, dim) + self.act1 = activation + + p = 0 + conv2 = [] + if padding_type == 'reflect': + conv2 += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv2 += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + conv2 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), InstanceNorm()] + self.conv2 = nn.Sequential(*conv2) + self.style2 = ApplyStyle(latent_size, dim) + + + def forward(self, x, dlatents_in_slice): + y = self.conv1(x) + y = self.style1(y, dlatents_in_slice) + y = self.act1(y) + y = self.conv2(y) + y = self.style2(y, dlatents_in_slice) + out = x + y + return out + +class OneNetwork(nn.Module): + def __init__(self, cfg): + super(OneNetwork, self).__init__() + self.num_verts = cfg.num_verts + self.input_size = cfg.input_size + kwargs = {} + num_classes = self.num_verts*5 + if cfg.task==1: + num_classes = self.num_verts*3 + elif cfg.task==2: + num_classes = 6 + elif cfg.task==3: + num_classes = self.num_verts*2 + if cfg.network.startswith('resnet'): + kwargs['base_width'] = int(64*cfg.width_mult) + p_num_classes = num_classes + if cfg.no_gap: + p_num_classes = 0 + kwargs['global_pool'] = None + elif cfg.use_arcface: + p_num_classes = 0 + kwargs['global_pool'] = None + if cfg.network=='resnet_jmlr': + from .resnet import resnet_jmlr + self.net = resnet_jmlr(num_classes = p_num_classes, **kwargs) + else: + self.net = timm.create_model(cfg.network, num_classes = p_num_classes, **kwargs) + + if cfg.no_gap: + in_channel = self.net.num_features + feat_hw = (self.input_size//32)**2 + mid_channel = 128 + self.no_gap_output = nn.Sequential(*[ + nn.BatchNorm2d(in_channel), + nn.Conv2d(in_channel, mid_channel, 1, stride=1, padding=0, bias=False), + nn.ReLU(), + nn.Flatten(1), + nn.Linear(mid_channel*feat_hw, num_classes)]) + + self.no_gap = cfg.no_gap + self.use_arcface = cfg.use_arcface + if 
self.use_arcface: + self.neta = arcface_get_model(cfg.arcface_model, input_size=cfg.arcface_input_size) + self.neta.load_state_dict(torch.load(cfg.arcface_ckpt, map_location=torch.device('cpu'))) + self.neta.eval() + self.neta.requires_grad_(False) + input_dim = 512 #resnet34d + z_dim = 512 #arcface_dim + hidden_dim = 256 + self.pool = nn.AdaptiveAvgPool2d(1) + self.flatten = nn.Flatten(1) + mlp_act = nn.LeakyReLU + + self.mlp = nn.Sequential(*[ + nn.Linear(z_dim, hidden_dim), + mlp_act(), + nn.Linear(hidden_dim, hidden_dim), + mlp_act(), + nn.Linear(hidden_dim, hidden_dim), + mlp_act(), + nn.Linear(hidden_dim, hidden_dim), + mlp_act(), + nn.Linear(hidden_dim, hidden_dim), + mlp_act(), + nn.Linear(hidden_dim, hidden_dim), + mlp_act(), + nn.Linear(hidden_dim, input_dim), + ]) + style_blocks = [] + for i in range(3): + style_blocks += [ResnetBlock_Adain(input_dim, latent_size=z_dim)] + self.style_blocks = nn.Sequential(*style_blocks) + self.branch2d = nn.Sequential(*[ + nn.Conv2d(input_dim, input_dim, 3, stride=1, padding=1, bias=False), + nn.BatchNorm2d(input_dim), + nn.ReLU(), + ]) + self.fc = nn.Linear(input_dim*2, num_classes) + + + def forward(self, x): + if self.use_arcface: + conv_feat = self.net.forward_features(x) + #input = self.flatten(self.pool(conv_feat)) + xa = F.interpolate(x, [144, 144], mode='bilinear', align_corners=False) + xa = xa[:,:,8:120,16:128] + z = self.neta(xa) + z = self.mlp(z) + + c = conv_feat + for i in range(len(self.style_blocks)): + c = self.style_blocks[i](c, z) + feat3 = c + feat2 = self.branch2d(conv_feat) + conv_feat = torch.cat([feat3, feat2], dim=1) + feat = self.flatten(self.pool(conv_feat)) + pred = self.fc(feat) + + elif self.no_gap: + y = self.net.forward_features(x) + pred = self.no_gap_output(y) + else: + pred = self.net(x) + return pred + +def get_network(cfg): + if cfg.use_onenetwork: + net = OneNetwork(cfg) + else: + net = timm.create_model(cfg.network, num_classes = 1220*5) + return net + + diff --git a/reconstruction/jmlr/backbones/resnet.py b/reconstruction/jmlr/backbones/resnet.py new file mode 100644 index 0000000..6dc9b9f --- /dev/null +++ b/reconstruction/jmlr/backbones/resnet.py @@ -0,0 +1,453 @@ +"""PyTorch ResNet + +This started as a copy of https://github.com/pytorch/vision 'resnet.py' (BSD-3-Clause) with +additional dropout and dynamic global avg/max pool. 
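+This copy additionally takes a per-stage `channels` list (default [64, 128, 256, 512]) so that resnet_jmlr below can use non-standard stage widths.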
+ +ResNeXt, SE-ResNeXt, SENet, and MXNet Gluon stem/downsample variants, tiered stems added by Ross Wightman + +Copyright 2019, Ross Wightman +""" +import math +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.models.helpers import build_model_with_cfg +from timm.models.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, create_attn, get_attn, create_classifier + + + +def get_padding(kernel_size, stride, dilation=1): + padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 + return padding + + +def create_aa(aa_layer, channels, stride=2, enable=True): + if not aa_layer or not enable: + return nn.Identity() + return aa_layer(stride) if issubclass(aa_layer, nn.AvgPool2d) else aa_layer(channels=channels, stride=stride) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__( + self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, + reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, + attn_layer=None, aa_layer=None, drop_block=None, drop_path=None): + super(BasicBlock, self).__init__() + + assert cardinality == 1, 'BasicBlock only supports cardinality of 1' + assert base_width == 64, 'BasicBlock does not support changing base width' + first_planes = planes // reduce_first + outplanes = planes * self.expansion + first_dilation = first_dilation or dilation + use_aa = aa_layer is not None and (stride == 2 or first_dilation != dilation) + + self.conv1 = nn.Conv2d( + inplanes, first_planes, kernel_size=3, stride=1 if use_aa else stride, padding=first_dilation, + dilation=first_dilation, bias=False) + self.bn1 = norm_layer(first_planes) + self.drop_block = drop_block() if drop_block is not None else nn.Identity() + self.act1 = act_layer(inplace=True) + self.aa = create_aa(aa_layer, channels=first_planes, stride=stride, enable=use_aa) + + self.conv2 = nn.Conv2d( + first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False) + self.bn2 = norm_layer(outplanes) + + self.se = create_attn(attn_layer, outplanes) + + self.act2 = act_layer(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + self.drop_path = drop_path + + def zero_init_last(self): + nn.init.zeros_(self.bn2.weight) + + def forward(self, x): + shortcut = x + + x = self.conv1(x) + x = self.bn1(x) + x = self.drop_block(x) + x = self.act1(x) + x = self.aa(x) + + x = self.conv2(x) + x = self.bn2(x) + + if self.se is not None: + x = self.se(x) + + if self.drop_path is not None: + x = self.drop_path(x) + + if self.downsample is not None: + shortcut = self.downsample(shortcut) + x += shortcut + x = self.act2(x) + + return x + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__( + self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, + reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, + attn_layer=None, aa_layer=None, drop_block=None, drop_path=None): + super(Bottleneck, self).__init__() + + width = int(math.floor(planes * (base_width / 64)) * cardinality) + first_planes = width // reduce_first + outplanes = planes * self.expansion + first_dilation = first_dilation or dilation + use_aa = aa_layer is not None and (stride == 2 or first_dilation != dilation) + + self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False) + self.bn1 = 
norm_layer(first_planes) + self.act1 = act_layer(inplace=True) + + self.conv2 = nn.Conv2d( + first_planes, width, kernel_size=3, stride=1 if use_aa else stride, + padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False) + self.bn2 = norm_layer(width) + self.drop_block = drop_block() if drop_block is not None else nn.Identity() + self.act2 = act_layer(inplace=True) + self.aa = create_aa(aa_layer, channels=width, stride=stride, enable=use_aa) + + self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False) + self.bn3 = norm_layer(outplanes) + + self.se = create_attn(attn_layer, outplanes) + + self.act3 = act_layer(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + self.drop_path = drop_path + + def zero_init_last(self): + nn.init.zeros_(self.bn3.weight) + + def forward(self, x): + shortcut = x + + x = self.conv1(x) + x = self.bn1(x) + x = self.act1(x) + + x = self.conv2(x) + x = self.bn2(x) + x = self.drop_block(x) + x = self.act2(x) + x = self.aa(x) + + x = self.conv3(x) + x = self.bn3(x) + + if self.se is not None: + x = self.se(x) + + if self.drop_path is not None: + x = self.drop_path(x) + + if self.downsample is not None: + shortcut = self.downsample(shortcut) + x += shortcut + x = self.act3(x) + + return x + + +def downsample_conv( + in_channels, out_channels, kernel_size, stride=1, dilation=1, first_dilation=None, norm_layer=None): + norm_layer = norm_layer or nn.BatchNorm2d + kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size + first_dilation = (first_dilation or dilation) if kernel_size > 1 else 1 + p = get_padding(kernel_size, stride, first_dilation) + + return nn.Sequential(*[ + nn.Conv2d( + in_channels, out_channels, kernel_size, stride=stride, padding=p, dilation=first_dilation, bias=False), + norm_layer(out_channels) + ]) + + +def downsample_avg( + in_channels, out_channels, kernel_size, stride=1, dilation=1, first_dilation=None, norm_layer=None): + norm_layer = norm_layer or nn.BatchNorm2d + avg_stride = stride if dilation == 1 else 1 + if stride == 1 and dilation == 1: + pool = nn.Identity() + else: + avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d + pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False) + + return nn.Sequential(*[ + pool, + nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0, bias=False), + norm_layer(out_channels) + ]) + + +def drop_blocks(drop_prob=0.): + return [ + None, None, + partial(DropBlock2d, drop_prob=drop_prob, block_size=5, gamma_scale=0.25) if drop_prob else None, + partial(DropBlock2d, drop_prob=drop_prob, block_size=3, gamma_scale=1.00) if drop_prob else None] + + +def make_blocks( + block_fn, channels, block_repeats, inplanes, reduce_first=1, output_stride=32, + down_kernel_size=1, avg_down=False, drop_block_rate=0., drop_path_rate=0., **kwargs): + stages = [] + feature_info = [] + net_num_blocks = sum(block_repeats) + net_block_idx = 0 + net_stride = 4 + dilation = prev_dilation = 1 + for stage_idx, (planes, num_blocks, db) in enumerate(zip(channels, block_repeats, drop_blocks(drop_block_rate))): + stage_name = f'layer{stage_idx + 1}' # never liked this name, but weight compat requires it + stride = 1 if stage_idx == 0 else 2 + if net_stride >= output_stride: + dilation *= stride + stride = 1 + else: + net_stride *= stride + + downsample = None + if stride != 1 or inplanes != planes * block_fn.expansion: + down_kwargs = dict( + in_channels=inplanes, out_channels=planes * 
block_fn.expansion, kernel_size=down_kernel_size, + stride=stride, dilation=dilation, first_dilation=prev_dilation, norm_layer=kwargs.get('norm_layer')) + downsample = downsample_avg(**down_kwargs) if avg_down else downsample_conv(**down_kwargs) + + block_kwargs = dict(reduce_first=reduce_first, dilation=dilation, drop_block=db, **kwargs) + blocks = [] + for block_idx in range(num_blocks): + downsample = downsample if block_idx == 0 else None + stride = stride if block_idx == 0 else 1 + block_dpr = drop_path_rate * net_block_idx / (net_num_blocks - 1) # stochastic depth linear decay rule + blocks.append(block_fn( + inplanes, planes, stride, downsample, first_dilation=prev_dilation, + drop_path=DropPath(block_dpr) if block_dpr > 0. else None, **block_kwargs)) + prev_dilation = dilation + inplanes = planes * block_fn.expansion + net_block_idx += 1 + + stages.append((stage_name, nn.Sequential(*blocks))) + feature_info.append(dict(num_chs=inplanes, reduction=net_stride, module=stage_name)) + + return stages, feature_info + + +class ResNet(nn.Module): + """ResNet / ResNeXt / SE-ResNeXt / SE-Net + + This class implements all variants of ResNet, ResNeXt, SE-ResNeXt, and SENet that + * have > 1 stride in the 3x3 conv layer of bottleneck + * have conv-bn-act ordering + + This ResNet impl supports a number of stem and downsample options based on the v1c, v1d, v1e, and v1s + variants included in the MXNet Gluon ResNetV1b model. The C and D variants are also discussed in the + 'Bag of Tricks' paper: https://arxiv.org/pdf/1812.01187. The B variant is equivalent to torchvision default. + + ResNet variants (the same modifications can be used in SE/ResNeXt models as well): + * normal, b - 7x7 stem, stem_width = 64, same as torchvision ResNet, NVIDIA ResNet 'v1.5', Gluon v1b + * c - 3 layer deep 3x3 stem, stem_width = 32 (32, 32, 64) + * d - 3 layer deep 3x3 stem, stem_width = 32 (32, 32, 64), average pool in downsample + * e - 3 layer deep 3x3 stem, stem_width = 64 (64, 64, 128), average pool in downsample + * s - 3 layer deep 3x3 stem, stem_width = 64 (64, 64, 128) + * t - 3 layer deep 3x3 stem, stem width = 32 (24, 48, 64), average pool in downsample + * tn - 3 layer deep 3x3 stem, stem width = 32 (24, 32, 64), average pool in downsample + + ResNeXt + * normal - 7x7 stem, stem_width = 64, standard cardinality and base widths + * same c,d, e, s variants as ResNet can be enabled + + SE-ResNeXt + * normal - 7x7 stem, stem_width = 64 + * same c, d, e, s variants as ResNet can be enabled + + SENet-154 - 3 layer deep 3x3 stem (same as v1c-v1s), stem_width = 64, cardinality=64, + reduction by 2 on width of first bottleneck convolution, 3x3 downsample convs after first block + + Parameters + ---------- + block : Block, class for the residual block. Options are BasicBlockGl, BottleneckGl. + layers : list of int, number of layers in each block + num_classes : int, default 1000, number of classification classes. + in_chans : int, default 3, number of input (color) channels. + output_stride : int, default 32, output stride of the network, 32, 16, or 8. + global_pool : str, Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax' + cardinality : int, default 1, number of convolution groups for 3x3 conv in Bottleneck. + base_width : int, default 64, factor determining bottleneck channels. 
`planes * base_width / 64 * cardinality` + stem_width : int, default 64, number of channels in stem convolutions + stem_type : str, default '' + The type of stem: + * '', default - a single 7x7 conv with a width of stem_width + * 'deep' - three 3x3 convolution layers of widths stem_width, stem_width, stem_width * 2 + * 'deep_tiered' - three 3x3 conv layers of widths stem_width//4 * 3, stem_width, stem_width * 2 + block_reduce_first : int, default 1 + Reduction factor for first convolution output width of residual blocks, 1 for all archs except senets, where 2 + down_kernel_size : int, default 1, kernel size of residual block downsample path, 1x1 for most, 3x3 for senets + avg_down : bool, default False, use average pooling for projection skip connection between stages/downsample. + act_layer : nn.Module, activation layer + norm_layer : nn.Module, normalization layer + aa_layer : nn.Module, anti-aliasing layer + drop_rate : float, default 0. Dropout probability before classifier, for training + """ + + def __init__( + self, block, layers, num_classes=1000, in_chans=3, output_stride=32, global_pool='avg', + cardinality=1, base_width=64, stem_width=64, stem_type='', replace_stem_pool=False, block_reduce_first=1, + down_kernel_size=1, avg_down=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, + drop_rate=0.0, drop_path_rate=0., drop_block_rate=0., zero_init_last=True, block_args=None, channels=[64, 128, 256, 512]): + super(ResNet, self).__init__() + block_args = block_args or dict() + assert output_stride in (8, 16, 32) + self.num_classes = num_classes + self.drop_rate = drop_rate + self.grad_checkpointing = False + + # Stem + deep_stem = 'deep' in stem_type + inplanes = stem_width * 2 if deep_stem else 64 + if deep_stem: + stem_chs = (stem_width, stem_width) + if 'tiered' in stem_type: + stem_chs = (3 * (stem_width // 4), stem_width) + self.conv1 = nn.Sequential(*[ + nn.Conv2d(in_chans, stem_chs[0], 3, stride=2, padding=1, bias=False), + norm_layer(stem_chs[0]), + act_layer(inplace=True), + nn.Conv2d(stem_chs[0], stem_chs[1], 3, stride=1, padding=1, bias=False), + norm_layer(stem_chs[1]), + act_layer(inplace=True), + nn.Conv2d(stem_chs[1], inplanes, 3, stride=1, padding=1, bias=False)]) + else: + self.conv1 = nn.Conv2d(in_chans, inplanes, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = norm_layer(inplanes) + self.act1 = act_layer(inplace=True) + self.feature_info = [dict(num_chs=inplanes, reduction=2, module='act1')] + + # Stem pooling. The name 'maxpool' remains for weight compatibility. 
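+ # Three variants below: replace_stem_pool swaps the max-pool for a stride-2 conv (optionally anti-aliased); aa_layer alone gives a stride-1 max-pool followed by an anti-aliased downsample; otherwise the standard 3x3 stride-2 max-pool is used.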
+ if replace_stem_pool: + self.maxpool = nn.Sequential(*filter(None, [ + nn.Conv2d(inplanes, inplanes, 3, stride=1 if aa_layer else 2, padding=1, bias=False), + create_aa(aa_layer, channels=inplanes, stride=2) if aa_layer is not None else None, + norm_layer(inplanes), + act_layer(inplace=True) + ])) + else: + if aa_layer is not None: + if issubclass(aa_layer, nn.AvgPool2d): + self.maxpool = aa_layer(2) + else: + self.maxpool = nn.Sequential(*[ + nn.MaxPool2d(kernel_size=3, stride=1, padding=1), + aa_layer(channels=inplanes, stride=2)]) + else: + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + # Feature Blocks + #channels = [64, 128, 256, 512] + stage_modules, stage_feature_info = make_blocks( + block, channels, layers, inplanes, cardinality=cardinality, base_width=base_width, + output_stride=output_stride, reduce_first=block_reduce_first, avg_down=avg_down, + down_kernel_size=down_kernel_size, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer, + drop_block_rate=drop_block_rate, drop_path_rate=drop_path_rate, **block_args) + for stage in stage_modules: + self.add_module(*stage) # layer1, layer2, etc + self.feature_info.extend(stage_feature_info) + + # Head (Pooling and Classifier) + self.num_features = 512 * block.expansion + self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + self.init_weights(zero_init_last=zero_init_last) + + @torch.jit.ignore + def init_weights(self, zero_init_last=True): + for n, m in self.named_modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + if zero_init_last: + for m in self.modules(): + if hasattr(m, 'zero_init_last'): + m.zero_init_last() + + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict(stem=r'^conv1|bn1|maxpool', blocks=r'^layer(\d+)' if coarse else r'^layer(\d+)\.(\d+)') + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self, name_only=False): + return 'fc' if name_only else self.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + def forward_features(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.act1(x) + x = self.maxpool(x) + + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq([self.layer1, self.layer2, self.layer3, self.layer4], x, flatten=True) + else: + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + x = self.global_pool(x) + if self.drop_rate: + x = F.dropout(x, p=float(self.drop_rate), training=self.training) + return x if pre_logits else self.fc(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +#def _create_resnet(variant, pretrained=False, **kwargs): +# return build_model_with_cfg(ResNet, variant, pretrained, **kwargs) + + + +def resnet34d(pretrained=False, **kwargs): + """Constructs a ResNet-34-D model. 
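+ Deep 3x3 stem (stem_width=32) with average-pool downsampling in the shortcuts, i.e. the 'D' tricks from the Bag-of-Tricks paper.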
+ """ + model_args = dict( + block=BasicBlock, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return ResNet(**model_args) + + +def resnet_jmlr(pretrained=False, **kwargs): + model_args = dict( + block=BasicBlock, layers=[5, 3, 4, 2], stem_width=32, stem_type='deep', avg_down=True, channels=[64,160,272,512], **kwargs) + #return _create_resnet('resnet34d', pretrained, **model_args) + return ResNet(**model_args) + diff --git a/reconstruction/jmlr/configs/base.py b/reconstruction/jmlr/configs/base.py new file mode 100644 index 0000000..c5d61d3 --- /dev/null +++ b/reconstruction/jmlr/configs/base.py @@ -0,0 +1,101 @@ +from easydict import EasyDict as edict +import numpy as np + +config = edict() +config.embedding_size = 512 +config.sample_rate = 1 +config.fp16 = 0 +config.tf32 = False +config.backbone_wd = None +config.batch_size = 128 +config.clip_grad = None +config.dropout = 0.0 +#config.warmup_epoch = -1 +config.loss = 'cosface' +config.margin = 0.4 +config.hard_margin = False +config.network = 'r50' +config.prelu = True +config.stem_type = '' +config.dropblock = 0.0 +config.output = None +config.input_size = 112 +config.width_mult = 1.0 +config.kaiming_init = True +config.use_se = False +config.aug_modes = [] +config.checkpoint_segments = [1, 1, 1, 1] + +config.sampling_id = True +config.id_sampling_ratio = None +metric_loss = edict() +metric_loss.enable = False +metric_loss.lambda_n = 0.0 +metric_loss.lambda_c = 0.0 +metric_loss.lambda_t = 0.0 +metric_loss.margin_c = 1.0 +metric_loss.margin_t = 1.0 +metric_loss.margin_n = 0.4 +config.metric_loss = metric_loss + +config.opt = 'sgd' +config.lr = 0.1 # when batch size is 512 +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.fc_mom = 0.9 + +config.warmup_epochs = 0 +config.max_warmup_steps = 6000 +config.num_epochs = 24 + + +config.resume = False +config.resume_path = None +config.resume_from = None + +config.save_every_epochs = True + +config.lr_func = None +config.lr_epochs = None +config.save_pfc = False +config.save_onnx = False +config.save_opt = False + +config.label_6dof_mean = np.array([-0.018197, -0.017891, 0.025348, -0.005368, 0.001176, -0.532206], dtype=np.float32) # mean of pitch, yaw, roll, tx, ty, tz +config.label_6dof_std = np.array([0.314015, 0.271809, 0.081881, 0.022173, 0.048839, 0.065444], dtype=np.float32) # std of pitch, yaw, roll, tx, ty, tz + +config.num_verts = 1220 +config.flipindex_file = 'cache_align/flip_index.npy' +config.enable_flip = True +config.verts3d_central_index = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 956, 975, 1022, 1041, 1047, 1048, 1049] + +config.task = 0 +config.ckpt = None +config.loss_hard = False +config.sampling_hard = False +config.loss_pip = False +config.net_stride = 32 +config.loss_bone3d = False +config.loss_bone2d = False + +config.lossw_verts3d = 8.0 +config.lossw_verts2d = 16.0 +config.lossw_bone3d = 10.0 +config.lossw_bone2d = 10.0 +config.lossw_project = 10.0 + +config.align_face = False +config.no_gap = False + +config.use_trainval = False + +config.project_loss = False + +config.use_onenetwork = True + +config.use_rtloss = False + + +config.use_arcface = False + + diff --git a/reconstruction/jmlr/configs/s1.py b/reconstruction/jmlr/configs/s1.py new file mode 100644 index 0000000..95d7738 --- /dev/null +++ b/reconstruction/jmlr/configs/s1.py @@ -0,0 +1,54 @@ +from easydict import EasyDict as edict + +config = edict() + +config.dataset = 
"wcpa" +config.root_dir = '/data/insightface/wcpa' +config.cache_dir = './cache_align' +#config.num_classes = 617970 +#config.num_classes = 2000000 +#config.num_classes = 80000000 +#config.val_targets = ["lfw", "cfp_fp", "agedb_30"] +#config.val_targets = ["lfw"] +#config.val_targets = [] +config.verbose = 20000 + +#config.network = 'resnet34d' +config.network = 'resnet_jmlr' +config.input_size = 256 +#config.width_mult = 1.0 +#config.dropout = 0.0 +#config.loss = 'cosface' +#config.embedding_size = 512 +#config.sample_rate = 0.2 +config.fp16 = 0 +config.tf32 = True +config.weight_decay = 5e-4 +config.batch_size = 64 +config.lr = 0.1 # lr when batch size is 512 + +config.aug_modes = ['1'] + +config.num_epochs = 40 +config.warmup_epochs = 1 +config.max_warmup_steps = 1000 + +#def lr_step_func(epoch): +# return ((epoch + 1) / (4 + 1)) ** 2 if epoch < -1 else 0.1 ** len( +# [m for m in [20, 30, 38] if m - 1 <= epoch]) +#config.lr_func = lr_step_func + +config.task = 0 +config.save_every_epochs = False + + +config.lossw_verts3d = 16.0 + +config.align_face = True + +config.use_trainval = True +#config.use_rtloss = True + +config.loss_bone3d = True +config.lossw_bone3d = 2.0 + diff --git a/reconstruction/jmlr/dataset.py b/reconstruction/jmlr/dataset.py new file mode 100644 index 0000000..f8ac861 --- /dev/null +++ b/reconstruction/jmlr/dataset.py @@ -0,0 +1,700 @@ +import numbers +import os +import os.path as osp +import pickle +import queue as Queue +import threading +import logging +import numbers +import math +import pandas as pd +from scipy.spatial.transform import Rotation + +import mxnet as mx +from pathlib import Path +import numpy as np +import torch +from torch.utils.data import DataLoader, Dataset +from skimage import transform as sktrans +import cv2 +import albumentations as A +from albumentations.pytorch import ToTensorV2 +from augs import * + + +def Rt26dof(R_t, degrees=False): + yaw_gt, pitch_gt, roll_gt = Rotation.from_matrix(R_t[:3, :3].T).as_euler('yxz', degrees=degrees) + label_euler = np.array([pitch_gt, yaw_gt, roll_gt]) + label_translation = R_t[3, :3] + label_6dof = np.concatenate([label_euler, label_translation]) + return label_6dof + + +def gen_target_pip(target, target_map, target_local_x, target_local_y): + map_channel, map_height, map_width = target_map.shape + target = target.reshape(-1, 2) + assert map_channel == target.shape[0] + + for i in range(map_channel): + mu_x = int(math.floor(target[i][0] * map_width)) + mu_y = int(math.floor(target[i][1] * map_height)) + mu_x = max(0, mu_x) + mu_y = max(0, mu_y) + mu_x = min(mu_x, map_width-1) + mu_y = min(mu_y, map_height-1) + target_map[i, mu_y, mu_x] = 1 + shift_x = target[i][0] * map_width - mu_x + shift_y = target[i][1] * map_height - mu_y + target_local_x[i, mu_y, mu_x] = shift_x + target_local_y[i, mu_y, mu_x] = shift_y + + + return target_map, target_local_x, target_local_y + +def get_tris(cfg): + import trimesh + data_root = Path(cfg.root_dir) + obj_path = data_root / 'resources/example.obj' + mesh = trimesh.load(obj_path, process=False) + verts_template = np.array(mesh.vertices, dtype=np.float32) + tris = np.array(mesh.faces, dtype=np.int32) + #print(verts_template.shape, tris.shape) + return tris + + +class BackgroundGenerator(threading.Thread): + def __init__(self, generator, local_rank, max_prefetch=6): + super(BackgroundGenerator, self).__init__() + self.queue = Queue.Queue(max_prefetch) + self.generator = generator + self.local_rank = local_rank + self.daemon = True + self.start() + + def run(self): + 
torch.cuda.set_device(self.local_rank) + for item in self.generator: + self.queue.put(item) + self.queue.put(None) + + def next(self): + next_item = self.queue.get() + if next_item is None: + raise StopIteration + return next_item + + def __next__(self): + return self.next() + + def __iter__(self): + return self + + +class DataLoaderX(DataLoader): + def __init__(self, local_rank, **kwargs): + super(DataLoaderX, self).__init__(**kwargs) + self.stream = torch.cuda.Stream(local_rank) + self.local_rank = local_rank + + def __iter__(self): + self.iter = super(DataLoaderX, self).__iter__() + self.iter = BackgroundGenerator(self.iter, self.local_rank) + self.preload() + return self + + def preload(self): + self.batch = next(self.iter, None) + if self.batch is None: + return None + with torch.cuda.stream(self.stream): + for k in range(len(self.batch)): + self.batch[k] = self.batch[k].to(device=self.local_rank, + non_blocking=True) + + def __next__(self): + torch.cuda.current_stream().wait_stream(self.stream) + batch = self.batch + if batch is None: + raise StopIteration + self.preload() + return batch + + +class FaceDataset(Dataset): + def __init__(self, cfg, is_train=True, is_test=False, local_rank=0): + super(FaceDataset, self).__init__() + + + self.data_root = Path(cfg.root_dir) + self.input_size = cfg.input_size + self.transform = get_aug_transform(cfg) + self.local_rank = local_rank + self.is_test = is_test + txt_path = self.data_root / 'resources/projection_matrix.txt' + self.M_proj = np.loadtxt(txt_path, dtype=np.float32) + if is_test: + data_root = Path(cfg.root_dir) + csv_path = data_root / 'list/WCPA_track2_test.csv' + self.df = pd.read_csv(csv_path, dtype={'subject_id': str, 'facial_action': str, 'img_id': str}) + else: + if is_train: + self.df = pd.read_csv(osp.join(cfg.cache_dir, 'train_list.csv'), dtype={'subject_id': str, 'facial_action': str, 'img_id': str}) + else: + self.df = pd.read_csv(osp.join(cfg.cache_dir, 'val_list.csv'), dtype={'subject_id': str, 'facial_action': str, 'img_id': str}) + self.label_6dof_mean = [-0.018197, -0.017891, 0.025348, -0.005368, 0.001176, -0.532206] # mean of pitch, yaw, roll, tx, ty, tz + self.label_6dof_std = [0.314015, 0.271809, 0.081881, 0.022173, 0.048839, 0.065444] # std of pitch, yaw, roll, tx, ty, tz + self.align_face = cfg.align_face + if not self.align_face: + self.dst_pts = np.float32([ + [0, 0], + [0, cfg.input_size- 1], + [cfg.input_size- 1, 0] + ]) + else: + dst_pts = np.array([ + [38.2946, 51.6963], + [73.5318, 51.5014], + [56.0252, 71.7366], + [41.5493, 92.3655], + [70.7299, 92.2041] ], dtype=np.float32 ) + + new_size = 144 + dst_pts[:,0] += ((new_size-112)//2) + dst_pts[:,1] += 8 + dst_pts[:,:] *= (self.input_size/float(new_size)) + self.dst_pts = dst_pts + + if local_rank==0: + logging.info('data_transform_list:%s'%self.transform) + logging.info('len:%d'%len(self.df)) + self.is_test_aug = False + + def set_test_aug(self): + if not self.is_test_aug: + from easydict import EasyDict as edict + cfg = edict() + cfg.aug_modes = ['test-aug'] + cfg.input_size = self.input_size + cfg.task = 0 + self.transform = get_aug_transform(cfg) + self.is_test_aug = True + + def get_names(self, index): + subject_id = self.df['subject_id'][index] + facial_action = self.df['facial_action'][index] + img_id = self.df['img_id'][index] + return subject_id, facial_action, img_id + + def __getitem__(self, index): + subject_id = self.df['subject_id'][index] + facial_action = self.df['facial_action'][index] + img_id = self.df['img_id'][index] + + img_path = 
self.data_root / 'image' / subject_id / facial_action / f'{img_id}_ar.jpg' + npz_path = self.data_root / 'info' / subject_id / facial_action / f'{img_id}_info.npz' + txt_path = self.data_root / '68landmarks' / subject_id / facial_action / f'{img_id}_68landmarks.txt' + #if not osp.exists(img_path): + # continue + + #print(img_path) + img_raw = cv2.imread(str(img_path)) + #if img_raw is None: + # print('XXX ERR:', img_path) + img_raw = cv2.cvtColor(img_raw, cv2.COLOR_BGR2RGB) + #print(img_raw.shape) + img_h, img_w, _ = img_raw.shape + pts68 = np.loadtxt(txt_path, dtype=np.int32) + + x_min, y_min = pts68.min(axis=0) + x_max, y_max = pts68.max(axis=0) + x_center = (x_min + x_max) / 2 + y_center = (y_min + y_max) / 2 + w, h = x_max - x_min, y_max - y_min + + + if not self.align_face: + size = max(w, h) + ss = np.array([0.75, 0.75, 0.85, 0.65]) # predefined expand size + + left = x_center - ss[0] * size + right = x_center + ss[1] * size + top = y_center - ss[2] * size + bottom = y_center + ss[3] * size + + src_pts = np.float32([ + [left, top], + [left, bottom], + [right, top] + ]) + tform = cv2.getAffineTransform(src_pts, self.dst_pts) + else: + src_pts = np.float32([ + (pts68[36] + pts68[39])/2, + (pts68[42] + pts68[45])/2, + pts68[30], + pts68[48], + pts68[54] + ]) + tf = sktrans.SimilarityTransform() + tf.estimate(src_pts, self.dst_pts) + tform = tf.params[0:2,:] + + img_local = cv2.warpAffine(img_raw, tform, (self.input_size,)*2, flags=cv2.INTER_CUBIC) + fake_points2d = np.ones( (1,2), dtype=np.float32) * (self.input_size//2) + + #tform_inv = cv2.invertAffineTransform(tform) + #img_global = cv2.warpAffine(img_local, tform_inv, (img_w, img_h), borderValue=0.0) + #img_global = cv2.resize(img_global, (self.input_size, self.input_size)) + if self.transform is not None: + t = self.transform(image=img_local, keypoints=fake_points2d) + img_local = t['image'] + if self.is_test_aug: + height, width = img_local.shape[:2] + for trans in t["replay"]["transforms"]: + if trans['__class_fullname__']=='ShiftScaleRotate' and trans['applied']: + param = trans['params'] + dx, dy, angle, scale = param['dx'], param['dy'], param['angle'], param['scale'] + center = (width / 2, height / 2) + matrix = cv2.getRotationMatrix2D(center, angle, scale) + matrix[0, 2] += dx * width + matrix[1, 2] += dy * height + new_matrix = np.identity(3) + new_matrix[:2,:3] = matrix + old_tform = np.identity(3) + old_tform[:2,:3] = tform + #new_tform = np.dot(old_tform, new_matrix) + new_tform = np.dot(new_matrix, old_tform) + #print('label_tform:') + #print(label_tform.flatten()) + #print(new_matrix.flatten()) + #print(new_tform.flatten()) + tform = new_tform[:2,:3] + break + #print('trans param:', param) + #img_global = self.transform(image=img_global)['image'] + + tform_tensor = torch.tensor(tform, dtype=torch.float32) + if not self.is_test: + M = np.load(npz_path) + #yaw_gt, pitch_gt, roll_gt = Rotation.from_matrix(M['R_t'][:3, :3].T).as_euler('yxz', degrees=False) + #label_euler = np.array([pitch_gt, yaw_gt, roll_gt]) + #label_translation = M['R_t'][3, :3] + #label_6dof = np.concatenate([label_euler, label_translation]) + #label_6dof = (label_6dof - self.label_6dof_mean) / self.label_6dof_std + #label_6dof_tensor = torch.tensor(label_6dof, dtype=torch.float32) + #label_verts = M['verts'] * 10.0 # roughly [-1, 1] + #label_verts_tensor = torch.tensor(label_verts, dtype=torch.float32) + #return img_local, label_verts_tensor, label_6dof_tensor + label_verts_tensor = torch.tensor(M['verts'], dtype=torch.float32) + label_Rt_tensor = 
torch.tensor(M['R_t'], dtype=torch.float32) + #return img_local, img_global, label_verts_tensor, label_Rt_tensor, tform_tensor + return img_local, label_verts_tensor, label_Rt_tensor, tform_tensor + else: + #return img_local, img_global, tform_tensor + index_tensor = torch.tensor(index, dtype=torch.long) + return img_local, tform_tensor, index_tensor + + + def __len__(self): + return len(self.df) + +class MXFaceDataset(Dataset): + def __init__(self, cfg, is_train=True, norm_6dof=True, degrees_6dof=False, local_rank=0): + super(MXFaceDataset, self).__init__() + + + self.is_train = is_train + self.data_root = Path(cfg.root_dir) + self.input_size = cfg.input_size + self.transform = get_aug_transform(cfg) + self.local_rank = local_rank + self.use_trainval = cfg.use_trainval + if is_train: + #self.df = pd.read_csv(osp.join(cfg.cache_dir, 'train_list.csv'), dtype={'subject_id': str, 'facial_action': str, 'img_id': str}) + path_imgrec = os.path.join(cfg.cache_dir, 'train.rec') + path_imgidx = os.path.join(cfg.cache_dir, 'train.idx') + self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r') + self.imgidx = list(self.imgrec.keys) + self.imggroup = [0] * len(self.imgidx) + self.size_train = len(self.imgidx) + if self.use_trainval: + assert not cfg.sampling_hard + path_imgrec = os.path.join(cfg.cache_dir, 'val.rec') + path_imgidx = os.path.join(cfg.cache_dir, 'val.idx') + self.imgrec2 = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r') + imgidx2 = list(self.imgrec2.keys) + self.imggroup += [1] * len(imgidx2) + self.imgidx += imgidx2 + else: + #self.df = pd.read_csv(osp.join(cfg.cache_dir, 'val_list.csv'), dtype={'subject_id': str, 'facial_action': str, 'img_id': str}) + path_imgrec = os.path.join(cfg.cache_dir, 'val.rec') + path_imgidx = os.path.join(cfg.cache_dir, 'val.idx') + self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r') + self.imgidx = list(self.imgrec.keys) + self.imggroup = [0] * len(self.imgidx) + self.imgidx = np.array(self.imgidx) + self.imggroup = np.array(self.imggroup) + if cfg.sampling_hard and is_train: + meta = np.load(os.path.join(cfg.cache_dir, 'train.meta.npy')) + assert meta.shape[0]==len(self.imgidx) + new_imgidx = [] + for i in range(len(self.imgidx)): + idx = self.imgidx[i] + assert i==idx + pose = np.abs(meta[i,:2]) + #repeat = np.sum(pose>=35)*3+1 + if np.max(pose)<15: + repeat = 2 + else: + repeat = 1 + new_imgidx += [idx]*repeat + if local_rank==0: + print('new-imgidx:', len(self.imgidx), len(new_imgidx)) + self.imgidx = np.array(new_imgidx) + self.label_6dof_mean = [-0.018197, -0.017891, 0.025348, -0.005368, 0.001176, -0.532206] # mean of pitch, yaw, roll, tx, ty, tz + self.label_6dof_std = [0.314015, 0.271809, 0.081881, 0.022173, 0.048839, 0.065444] # std of pitch, yaw, roll, tx, ty, tz + txt_path = self.data_root / 'resources/projection_matrix.txt' + self.M_proj = np.loadtxt(txt_path, dtype=np.float32) + self.M1 = np.array([ + [400.0, 0, 0, 0], + [ 0, 400.0, 0, 0], + [ 0, 0, 1, 0], + [400.0, 400.0, 0, 1] + ]) + self.dst_pts = np.float32([ + [0, 0], + [0, cfg.input_size- 1], + [cfg.input_size- 1, 0] + ]) + self.norm_6dof = norm_6dof + self.degrees_6dof = degrees_6dof + self.task = cfg.task + self.num_verts = cfg.num_verts + self.loss_pip = cfg.loss_pip + self.net_stride = 32 + if local_rank==0: + logging.info('data_transform_list:%s'%self.transform) + logging.info('len:%d'%len(self.imgidx)) + logging.info('glen:%d'%len(self.imggroup)) + self.is_test_aug = False + + self.enable_flip = cfg.enable_flip + self.flipindex 
= cfg.flipindex.copy() + self.verts3d_central_index = cfg.verts3d_central_index + + def set_test_aug(self): + if not self.is_test_aug: + from easydict import EasyDict as edict + cfg = edict() + cfg.aug_modes = ['test-aug'] + cfg.input_size = self.input_size + cfg.task = 0 + self.transform = get_aug_transform(cfg) + self.is_test_aug = True + + def __getitem__(self, index): + idx = self.imgidx[index] + group = self.imggroup[index] + if group==0: + imgrec = self.imgrec + elif group==1: + imgrec = self.imgrec2 + elif group==2: + imgrec = self.imgrec3 + + s = imgrec.read_idx(idx) + header, img = mx.recordio.unpack(s) + hlabel = header.label + img = mx.image.imdecode(img).asnumpy() #rgb numpy + + label_verts = np.array(hlabel[:1220*3], dtype=np.float32).reshape(-1,3) + label_Rt = np.array(hlabel[1220*3:1220*3+16], dtype=np.float32).reshape(4,4) + label_tform = np.array(hlabel[1220*3+16:], dtype=np.float32).reshape(2,3) + label_6dof = Rt26dof(label_Rt, self.degrees_6dof) + if self.norm_6dof: + label_6dof = (label_6dof - self.label_6dof_mean) / self.label_6dof_std + label_6dof_tensor = torch.tensor(label_6dof, dtype=torch.float32) + + #img_local = None + img_raw = None + #if self.task==0 or self.task==2: + # img_raw = img[:,self.input_size:,:] + #if self.task==0 or self.task==1 or self.task==3: + # img_local = img[:,:self.input_size,:] + assert img.shape[0]==img.shape[1] and img.shape[0]>=self.input_size + if img.shape[0]>self.input_size: + scale = float(self.input_size) / img.shape[0] + #print('scale:', scale) + #src_pts = np.float32([ + # [0, 0], + # [0, 799], + # [799, 0] + #]) + #tform = cv2.getAffineTransform(src_pts, self.dst_pts) + #new_tform = np.identity(3) + #new_tform[:2,:3] = tform + #label_tform = np.dot(new_tform, label_tform.T).T + + src_pts = np.float32([ + [0, 0, 1], + [0, 799, 1], + [799, 0, 1] + ]) + dst_pts = np.dot(label_tform, src_pts.T).T + dst_pts *= scale + dst_pts = dst_pts.copy() + src_pts = src_pts[:,:2].copy() + #print('index:', index) + #print(src_pts.shape, dst_pts.shape) + #print(label_tform.shape) + #print(src_pts.dtype) + #print(dst_pts.dtype) + tform = cv2.getAffineTransform(src_pts, dst_pts) + label_tform = tform + + img = cv2.resize(img, (self.input_size, self.input_size)) + + img_local = img + need_points2d = (self.task==0 or self.task==3) + + if need_points2d: + ones = np.ones([label_verts.shape[0], 1]) + verts_homo = np.concatenate([label_verts, ones], axis=1) + verts = verts_homo @ label_Rt @ self.M_proj @ self.M1 + w_ = verts[:, [3]] + verts = verts / w_ + points2d = verts[:, :3] + points2d[:, 1] = 800.0 - points2d[:, 1] + verts2d = points2d[:,:2].copy() + points2d[:,2] = 1.0 + points2d = np.dot(label_tform, points2d.T).T + else: + points2d = np.ones( (1,2), dtype=np.float32) * (self.input_size//2) + #if img.shape[0]!=self.input_size: + # assert img.shape[0]>self.input_size + #img = cv2.resize(img, (self.input_size, self.input_size)) + #scale = float(self.input_size) / img.shape[0] + #points2d *= scale + + if self.transform is not None: + if img_raw is not None: + img_raw = self.transform(image=img_raw, keypoints=points2d)['image'] + if img_local is not None: + height, width = img_local.shape[:2] + x = self.transform(image=img_local, keypoints=points2d) + img_local = x['image'] + points2d = x['keypoints'] + points2d = np.array(points2d, dtype=np.float32) + if self.is_test_aug: + for trans in x["replay"]["transforms"]: + if trans['__class_fullname__']=='ShiftScaleRotate' and trans['applied']: + param = trans['params'] + dx, dy, angle, scale = param['dx'], 
param['dy'], param['angle'], param['scale']
+                        center = (width / 2, height / 2)
+                        matrix = cv2.getRotationMatrix2D(center, angle, scale)
+                        matrix[0, 2] += dx * width
+                        matrix[1, 2] += dy * height
+                        # compose the replayed ShiftScaleRotate with the existing crop transform
+                        new_matrix = np.identity(3)
+                        new_matrix[:2,:3] = matrix
+                        old_tform = np.identity(3)
+                        old_tform[:2,:3] = label_tform
+                        #new_tform = np.dot(old_tform, new_matrix)
+                        new_tform = np.dot(new_matrix, old_tform)
+                        #print('label_tform:')
+                        #print(label_tform.flatten())
+                        #print(new_matrix.flatten())
+                        #print(new_tform.flatten())
+                        label_tform = new_tform[:2,:3]
+                        break
+                #print('trans param:', param)
+
+        d = {}  # output dict; it must exist before the PIP branch below fills it
+        if self.loss_pip:
+            grid_size = int(self.input_size/self.net_stride)
+            target_map = np.zeros((self.num_verts, grid_size, grid_size))
+            target_local_x = np.zeros((self.num_verts, grid_size, grid_size))
+            target_local_y = np.zeros((self.num_verts, grid_size, grid_size))
+            target = points2d / self.input_size
+            target_map, target_local_x, target_local_y = gen_target_pip(target, target_map, target_local_x, target_local_y)
+            target_map_tensor = torch.tensor(target_map, dtype=torch.float32)
+            target_x_tensor = torch.tensor(target_local_x, dtype=torch.float32)
+            target_y_tensor = torch.tensor(target_local_y, dtype=torch.float32)
+            d['pip_map'] = target_map_tensor
+            d['pip_x'] = target_x_tensor
+            d['pip_y'] = target_y_tensor
+
+        if self.is_train and self.enable_flip and np.random.random()<0.5:
+            #if self.local_rank==0:
+            #    print('XXX:', label_verts[:5,:2])
+            img_local = img_local.flip([2])  # CHW tensor, flip the width axis
+            x_of_central = 0.0
+            #x_of_central = label_verts[self.verts3d_central_index,0]
+            #x_of_central = np.mean(x_of_central)
+            label_verts = label_verts[self.flipindex,:]
+            label_verts[:,0] -= x_of_central
+            label_verts[:,0] *= -1.0
+            label_verts[:,0] += x_of_central
+
+            if need_points2d:
+                flipped_p2d = points2d[self.flipindex,:].copy()
+                flipped_p2d[:,0] = self.input_size - 1 - flipped_p2d[:,0]
+                points2d = flipped_p2d
+        label_verts_tensor = torch.tensor(label_verts*10.0, dtype=torch.float32)
+        d['img_local'] = img_local
+        d['verts'] = label_verts_tensor
+        d['6dof'] = label_6dof_tensor
+        d['rt'] = torch.tensor(label_Rt, dtype=torch.float32)
+        if need_points2d:
+            points2d = points2d / (self.input_size//2) - 1.0
+            points2d_tensor = torch.tensor(points2d, dtype=torch.float32)
+            d['points2d'] = points2d_tensor
+
+        loss_weight = 1.0
+        if group!=0:
+            loss_weight = 0.0
+        loss_weight_tensor = torch.tensor(loss_weight, dtype=torch.float32)
+        d['loss_weight'] = loss_weight_tensor
+        label_tform_tensor = torch.tensor(label_tform, dtype=torch.float32)
+        d['tform'] = label_tform_tensor
+
+        #if img_local is None:
+        #    image = (img_raw,)
+        #elif img_raw is None:
+        #    image = (img_local,)
+        #else:
+        #    image = (img_local,img_raw)
+        #ret = image + (label_verts_tensor, label_6dof_tensor, points2d_tensor)
+        if not self.is_train:
+            idx_tensor = torch.tensor([idx], dtype=torch.long)
+            d['idx'] = idx_tensor
+            # verts2d only exists when need_points2d (task 0 or 3); evaluation runs with task 0
+            d['verts2d'] = torch.tensor(verts2d, dtype=torch.float32)
+        return d
+
+
+    def __len__(self):
+        return len(self.imgidx)
+
+def test_dataset1(cfg):
+    cfg.task = 0
+    is_train = False
+    center_axis = []
+    dataset = MXFaceDataset(cfg, is_train=is_train, norm_6dof=False, local_rank=0)
+    for i in range(len(dataset.flipindex)):
+        if i==dataset.flipindex[i]:
+            center_axis.append(i)
+    print(center_axis)
+    #dataset.transform = None
+    print('total:', len(dataset))
+    total = len(dataset)
+    #total = 50
+    list_6dof = []
+    all_mean_xs = []
+
for idx in range(total): + #img_local, img_raw, label_verts, label_6dof, = dataset[idx] + #img_local, img_raw, label_verts, label_6dof, points2d, tform, data_idx = dataset[idx] + #img_local, label_verts, label_6dof, points2d, tform, data_idx = dataset[idx] + d = dataset[idx] + img_local = d['img_local'] + label_verts = d['verts'] + label_6dof = d['6dof'] + points2d = d['points2d'] + label_verts = label_verts.numpy() + label_6dof = label_6dof.numpy() + points2d = points2d.numpy() + #print(img_local.shape, label_verts.shape, label_6dof.shape, points2d.shape) + verts3d = label_verts / 10.0 + xs = [] + for c in center_axis: + _x = verts3d[c,0] + xs.append(_x) + _std = np.std(xs) + print(xs) + print(_std) + #print(np.mean(xs)) + all_mean_xs.append(np.mean(xs)) + if idx%100==0: + print('processing:', idx, np.mean(all_mean_xs)) + #print(label_verts[:3,:], label_6dof) + #list_6dof.append(label_6dof) + #print(image.__class__, label_verts.__class__) + #label = list(label_verts.numpy().flatten()) + list(label_6dof.numpy().flatten()) + #points2d = label_verts2[:,:2] + #points2d = (points2d+1) * 128.0 + #img_local = img_local.numpy() + #img_local = (img_local+1.0) * 128.0 + #draw = img_local.astype(np.uint8).transpose( (1,2,0) )[:,:,::-1].copy() + #for i in range(points2d.shape[0]): + # pt = points2d[i].astype(np.int) + # cv2.circle(draw, pt, 2, (255,0,0), 2) + ##output_path = "outputs/%d_%.3f_%.3f_%.3f.jpg"%(idx, label_6dof[0], label_6dof[1], label_6dof[2]) + #output_path = "outputs/%06d.jpg"%(idx) + #cv2.imwrite(output_path, draw) + #list_6dof = np.array(list_6dof) + #print('MEAN:') + #print(np.mean(list_6dof, axis=0)) + +def test_loader1(cfg): + cfg.task = 0 + is_train = True + dataset = MXFaceDataset(cfg, is_train=is_train, norm_6dof=False, local_rank=0) + loader = DataLoader(dataset, batch_size=64, shuffle=True) + for index, d in enumerate(loader): + #img_local = d['img_local'] + label_verts = d['verts'] + points2d = d['points2d'] + tform = d['tform'] + label_verts /= 10.0 + points2d = (points2d + 1.0) * (cfg.input_size//2) + tform = tform.numpy() + verts = label_verts.numpy() + points2d = points2d.numpy() + print(verts.shape, points2d.shape, tform.shape) + np.save("temp/verts3d.npy", verts) + np.save("temp/points2d.npy", points2d) + np.save("temp/tform.npy", tform) + break + +def test_facedataset1(cfg): + cfg.task = 0 + cfg.input_size = 512 + dataset = FaceDataset(cfg, is_train=True, local_rank=0) + for idx in range(100000): + img_local, label_verts, label_Rt, tform = dataset[idx] + label_Rt = label_Rt.numpy() + if label_Rt[0,0]>1.0: + print(idx, label_Rt.shape) + print(label_Rt) + break + +def test_arcface(cfg): + cfg.task = 0 + is_train = True + dataset = MXFaceDataset(cfg, is_train=is_train, norm_6dof=False, local_rank=0) + loader = DataLoader(dataset, batch_size=1, shuffle=True) + for index, d in enumerate(loader): + img = d['img_local'].numpy() + img /= 2.0 + img += 0.5 + img *= 255.0 + img = img[0] + img = img.transpose( (1,2,0) ) + img = img.astype(np.uint8) + img = cv2.resize(img, (144,144)) + img = img[:,:,::-1] + img = img[8:120,16:128,:] + print(img.shape) + cv2.imwrite("temp/arc_%d.jpg"%index, img) + #np.save("temp/verts3d.npy", verts) + #np.save("temp/points2d.npy", points2d) + #np.save("temp/tform.npy", tform) + if index>100: + break + +if __name__ == "__main__": + from utils.utils_config import get_config + cfg = get_config('configs/r0_a1.py') + #test_loader1(cfg) + #test_facedataset1(cfg) + test_arcface(cfg) + + diff --git a/reconstruction/jmlr/flops.py 
b/reconstruction/jmlr/flops.py
new file mode 100644
index 0000000..605465d
--- /dev/null
+++ b/reconstruction/jmlr/flops.py
@@ -0,0 +1,25 @@
+from ptflops import get_model_complexity_info
+import os
+import argparse
+from utils.utils_config import get_config
+from backbones import get_network
+
+if __name__ == '__main__':
+
+    parser = argparse.ArgumentParser(description='JMLR FLOPs')
+    parser.add_argument('config', type=str, help='input config file')
+    args = parser.parse_args()
+    cfg = get_config(args.config)
+    #backbone = get_model(cfg.network, num_features=cfg.embedding_size, input_size=cfg.input_size, dropout=cfg.dropout, stem_type=cfg.stem_type, fp16=0)
+    net = get_network(cfg)
+    macs, params = get_model_complexity_info(
+        net, (3, cfg.input_size, cfg.input_size), as_strings=True,
+        print_per_layer_stat=True, verbose=True)
+    print(macs)
+    print(params)
+
+    # from torch import distributed
+    # distributed.AllreduceOptions
+    # distributed.AllreduceCoalescedOptions
+    # distributed.all_reduce
diff --git a/reconstruction/jmlr/gen_dataset_meta.py b/reconstruction/jmlr/gen_dataset_meta.py
new file mode 100644
index 0000000..108cdfe
--- /dev/null
+++ b/reconstruction/jmlr/gen_dataset_meta.py
@@ -0,0 +1,35 @@
+import pickle
+import numpy as np
+import os
+import os.path as osp
+import glob
+import argparse
+import cv2
+import time
+import datetime
+import sklearn
+import mxnet as mx
+from utils.utils_config import get_config
+from dataset import MXFaceDataset, Rt26dof
+
+if __name__ == "__main__":
+    cfg = get_config('configs/s1.py')
+    cfg.task = 0
+    save_path = os.path.join(cfg.cache_dir, 'train.meta')
+    assert not osp.exists(save_path+'.npy')  # np.save appends the .npy suffix
+    dataset = MXFaceDataset(cfg, is_train=True, norm_6dof=False, degrees_6dof=True, local_rank=0)
+    #dataset.transform = None
+    print('total:', len(dataset))
+    total = len(dataset)
+    meta = np.zeros( (total, 3), dtype=np.float32 )
+    for idx in range(total):
+        # MXFaceDataset.__getitem__ returns a dict of tensors
+        d = dataset[idx]
+        pose = d['6dof'].numpy()[:3]
+        print(idx, pose)
+        meta[idx] = pose
+
+    np.save(save_path, meta)
diff --git a/reconstruction/jmlr/inference_simple.py b/reconstruction/jmlr/inference_simple.py
new file mode 100644
index 0000000..c926f5d
--- /dev/null
+++ b/reconstruction/jmlr/inference_simple.py
@@ -0,0 +1,189 @@
+
+import os
+import time
+import timm
+import glob
+import numpy as np
+import os.path as osp
+import cv2
+
+import torch
+import torch.distributed as dist
+from torch import nn
+from pathlib import Path
+from backbones import get_network
+from skimage import transform as sktrans
+
+
+def solver_rigid(pts_3d , pts_2d , camera_matrix):
+    # pts_3d Nx3
+    # pts_2d Nx2
+    # camera_matrix 3x3 pinhole intrinsics
+    dist_coeffs = np.zeros((4,1))
+    pts_3d = pts_3d.copy()
+    pts_2d = pts_2d.copy()
+    success, rotation_vector, translation_vector = cv2.solvePnP(pts_3d, pts_2d, camera_matrix.copy(), dist_coeffs, flags=0)
+    assert success
+    R, _ = cv2.Rodrigues(rotation_vector)
+    # convert the solved pose from the OpenCV camera convention to the dataset convention
+    R = R.T
+    R[:,1:3] *= -1
+    T = translation_vector.flatten()
+    T[1:] *= -1
+
+    return R,T
+
+
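+# --- Illustrative sketch (editor's addition, not in the original source):
+# a minimal sanity check for solver_rigid under an assumed 3x3 pinhole
+# intrinsic matrix. The 2D points are produced by projecting the 3D points
+# with an identity pose, so cv2.solvePnP gets a consistent problem.
+def _solver_rigid_demo():
+    K = np.array([[400.0, 0.0, 400.0],
+                  [0.0, 400.0, 400.0],
+                  [0.0, 0.0, 1.0]], dtype=np.float32)
+    pts_3d = np.random.rand(50, 3).astype(np.float32)
+    pts_3d[:, 2] += 5.0               # keep the cloud in front of the camera
+    proj = pts_3d @ K.T               # pinhole projection, identity extrinsics
+    pts_2d = (proj[:, :2] / proj[:, 2:3]).astype(np.float32)
+    R, T = solver_rigid(pts_3d, pts_2d, K)
+    assert R.shape == (3, 3) and T.shape == (3,)
+
+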
+class JMLRInference(nn.Module):
+    def __init__(self, cfg, local_rank=0):
+        super(JMLRInference, self).__init__()
+        backbone = get_network(cfg)
+        if cfg.ckpt is None:
+            ckpts = list(glob.glob(osp.join(cfg.output, "backbone*.pth")))
+            backbone_pth = sorted(ckpts)[-1]
+        else:
+            backbone_pth = cfg.ckpt
+        if local_rank==0:
+            print(backbone_pth)
+        backbone_ckpt = torch.load(backbone_pth, map_location=torch.device(local_rank))
+        if 'model' in backbone_ckpt:
+            backbone_ckpt = backbone_ckpt['model']
+        backbone.load_state_dict(backbone_ckpt)
+        backbone.eval()
+        backbone.requires_grad_(False)
+        self.backbone = backbone  # register the loaded network; forward() uses it
+        self.flipindex = cfg.flipindex.copy()  # used by the is_flip paths below
+        self.num_verts = cfg.num_verts
+        self.input_size = cfg.input_size
+        self.data_root = Path(cfg.root_dir)
+        txt_path = self.data_root / 'resources/projection_matrix.txt'
+        self.M_proj = np.loadtxt(txt_path, dtype=np.float32)
+
+    def set_raw_image_size(self, width, height):
+        w = width / 2.0
+        h = height / 2.0
+        M1 = np.array([
+            [w, 0, 0, 0],
+            [ 0, h, 0, 0],
+            [ 0, 0, 1, 0],
+            [w, h, 0, 1]
+        ])
+        camera_matrix = self.M_proj @ M1
+        camera_matrix = camera_matrix[:3,:3].T
+        camera_matrix[0,2] = w
+        camera_matrix[1,2] = h
+        self.camera_matrix = camera_matrix
+
+
+    def forward(self, img_local, is_flip=False):
+        if is_flip:
+            img_local = img_local.flip([3])
+        pred = self.backbone(img_local)
+        pred1 = pred[:,:1220*3]
+        pred2 = pred[:,1220*3:]
+        meta = {'flip': is_flip}
+        return pred1, pred2, meta
+
+
+    def convert_verts(self, pred1, meta):
+        is_flip = meta['flip']
+        pred1 = pred1.cpu().numpy()
+        pred1 = pred1[:,:1220*3]
+        pred_verts = pred1.reshape(-1,1220,3) / 10.0
+        if is_flip:
+            pred_verts = pred_verts[:,self.flipindex,:]
+            pred_verts[:,:,0] *= -1.0
+        return pred_verts
+
+    def convert_2d(self, pred2, tforms, meta):
+        is_flip = meta['flip']
+        tforms = tforms.cpu().numpy()
+        pred2 = pred2.cpu().numpy()
+        points2d = (pred2.reshape(-1,1220,2)+1.0) * (self.input_size//2)
+        if is_flip:
+            points2d = points2d[:,self.flipindex,:]
+            points2d[:,:,0] = self.input_size - 1 - points2d[:,:,0]
+        B = points2d.shape[0]
+        points2de = np.ones( (points2d.shape[0], points2d.shape[1], 3), dtype=points2d.dtype)
+        points2de[:,:,:2] = points2d
+        verts2d = np.zeros((B,1220,2), dtype=np.float32)
+        for n in range(B):
+            tform = tforms[n]
+            tform_inv = cv2.invertAffineTransform(tform)
+            _points2d = np.dot(tform_inv, points2de[n].T).T
+            verts2d[n] = _points2d
+        return verts2d, points2d
+
+    def solve(self, verts3d, verts2d):
+        B = verts3d.shape[0]
+        R = np.zeros([B, 3, 3], dtype=np.float32)
+        t = np.zeros([B, 1, 3], dtype=np.float32)
+        for n in range(B):
+            _R, _t = solver_rigid(verts3d[n], verts2d[n], self.camera_matrix)
+            R[n] = _R
+            t[n,0] = _t
+        return R, t
+
+    def solve_one(self, verts3d, verts2d):
+        R, t = solver_rigid(verts3d, verts2d, self.camera_matrix)
+        return R, t
+
+
+def get(net, img, keypoints):
+    dst_pts = np.array([
+        [38.2946, 51.6963],
+        [73.5318, 51.5014],
+        [56.0252, 71.7366],
+        [41.5493, 92.3655],
+        [70.7299, 92.2041] ], dtype=np.float32 )
+    input_size = 256
+    local_rank = 0
+
+    new_size = 144
+    dst_pts[:,0] += ((new_size-112)//2)
+    dst_pts[:,1] += 8
+    dst_pts[:,:] *= (input_size/float(new_size))
+    tf = sktrans.SimilarityTransform()
+    tf.estimate(keypoints, dst_pts)
+    tform = tf.params[0:2,:]
+    img = cv2.warpAffine(img, tform, (input_size,)*2)
+    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    img = np.transpose(img, (2, 0, 1))
+    img = torch.from_numpy(img).unsqueeze(0).float()
+    img.div_(255).sub_(0.5).div_(0.5)
+    img_local = img.to(local_rank)
+    with torch.no_grad():
+        pred1, pred2, meta = net(img_local, is_flip=False)
+        pred_verts = net.convert_verts(pred1, meta)
+        tform = torch.from_numpy(tform.reshape(1,2,3))
+        pred_verts2d, pred_points2d = net.convert_2d(pred2, tform, meta)
+    return pred_verts[0], pred_verts2d[0]
+
+
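+# Editor's note: `get` assumes `keypoints` holds the five ArcFace template
+# points (eye centers, nose tip, mouth corners), e.g. the `kps` field of an
+# insightface detection result, and returns the (1220, 3) model-space
+# vertices together with their (1220, 2) projections in raw-image pixels.
+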
+if __name__ == "__main__":
+    import argparse
+    from utils.utils_config import get_config
+    from insightface.app import FaceAnalysis
+    parser = argparse.ArgumentParser(description='JMLR inference')
+    #parser.add_argument('config', type=str, help='config file')
+    config_file = 'configs/s1.py'
+    args = parser.parse_args()
+    cfg = get_config(config_file)
+    cfg2 = None
+    local_rank = 0
+    img = cv2.imread('sample.jpg')
+    net = JMLRInference(cfg, local_rank)
+    print(img.shape)
+    net.set_raw_image_size(img.shape[1], img.shape[0])
+    net = net.to(local_rank)
+    net.eval()
+    app = FaceAnalysis(allowed_modules='detection')
+    app.prepare(ctx_id=0, det_size=(640,640), det_thresh=0.5)
+    draw = img.copy()
+    faces = app.get(img)
+    for face in faces:
+        verts3d, verts2d = get(net, img, face.kps)
+        R, t = net.solve_one(verts3d, verts2d)
+        print(verts3d.shape, verts2d.shape, R.shape, t.shape)
+        for i in range(verts2d.shape[0]):
+            pt = tuple(verts2d[i].astype(int))
+            cv2.circle(draw, pt, 2, (255,0,0), 2)
+    cv2.imwrite('./draw.jpg', draw)
diff --git a/reconstruction/jmlr/losses.py b/reconstruction/jmlr/losses.py
new file mode 100644
index 0000000..8659f0c
--- /dev/null
+++ b/reconstruction/jmlr/losses.py
@@ -0,0 +1,111 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+import kornia
+import numpy as np
+
+#def loss_l1(a, b):
+    #_loss = torch.abs(a - b)
+    #_loss = torch.mean(_loss, dim=1)
+    ##if epoch>4 and cfg.loss_hard:
+    ##    _loss, _ = torch.topk(_loss, k=int(cfg.batch_size*0.3))
+    #_loss = torch.mean(_loss)
+    #return _loss
+
+
+
+def loss_pip(outputs_map, outputs_local_x, outputs_local_y, labels_map, labels_local_x, labels_local_y):
+    # PIP-style landmark loss: MSE on the coarse heatmap plus L1 on the
+    # x/y offsets gathered at each landmark's active grid cell.
+    tmp_batch, tmp_channel, tmp_height, tmp_width = outputs_map.size()
+    labels_map = labels_map.view(tmp_batch*tmp_channel, -1)
+    labels_max_ids = torch.argmax(labels_map, 1)
+    labels_max_ids = labels_max_ids.view(-1, 1)
+
+    #print('TTT:', outputs_local_x.shape, tmp_batch, tmp_channel)
+
+    outputs_local_x = outputs_local_x.reshape(tmp_batch*tmp_channel, -1)
+    outputs_local_x_select = torch.gather(outputs_local_x, 1, labels_max_ids)
+    outputs_local_y = outputs_local_y.reshape(tmp_batch*tmp_channel, -1)
+    outputs_local_y_select = torch.gather(outputs_local_y, 1, labels_max_ids)
+
+    labels_local_x = labels_local_x.view(tmp_batch*tmp_channel, -1)
+    labels_local_x_select = torch.gather(labels_local_x, 1, labels_max_ids)
+    labels_local_y = labels_local_y.view(tmp_batch*tmp_channel, -1)
+    labels_local_y_select = torch.gather(labels_local_y, 1, labels_max_ids)
+
+    labels_map = labels_map.view(tmp_batch, tmp_channel, tmp_height, tmp_width)
+    loss_map = F.mse_loss(outputs_map, labels_map)
+    loss_x = F.l1_loss(outputs_local_x_select, labels_local_x_select)
+    loss_y = F.l1_loss(outputs_local_y_select, labels_local_y_select)
+    return loss_map, loss_x, loss_y
+
+def eye_like(x: torch.Tensor, n: int) -> torch.Tensor:
+    return torch.eye(n, n, dtype=x.dtype, device=x.device).unsqueeze(0).repeat(x.shape[0], 1, 1)
+
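+# --- Illustrative sketch (editor's addition, not in the original source):
+# exercising loss_pip with dummy tensors of the shapes produced by
+# dataset.py, assuming 1220 landmarks on an 8x8 grid (input_size 256,
+# net_stride 32).
+def _loss_pip_demo():
+    B, C, H, W = 2, 1220, 8, 8
+    labels_map = torch.zeros(B, C, H, W)
+    labels_map[:, :, 0, 0] = 1.0      # one active cell per landmark channel
+    lx, ly = torch.rand(B, C, H, W), torch.rand(B, C, H, W)
+    out_map = torch.rand(B, C, H, W)
+    loss_map, loss_x, loss_y = loss_pip(out_map, lx.clone(), ly.clone(), labels_map, lx, ly)
+    return loss_map, loss_x, loss_y
+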
+class ProjectLoss(nn.Module):
+
+    def __init__(self,M_proj):
+        super(ProjectLoss, self).__init__()
+        img_w = 800
+        img_h = 800
+        M1 = np.array([
+            [img_w/2, 0, 0, 0],
+            [ 0, img_h/2, 0, 0],
+            [ 0, 0, 1, 0],
+            [img_w/2, img_h/2, 0, 1]
+        ])
+        M = M_proj @ M1
+        M = M.astype(np.float32)
+        self.register_buffer('M', torch.from_numpy(M))
+
+        camera_matrix = M[:3,:3].T.copy()
+        camera_matrix[0,2] = 400
+        camera_matrix[1,2] = 400
+        camera_matrix[2,2] = 1
+        intrinsics = np.array([camera_matrix]).astype(np.float64)
+        self.register_buffer('intrinsics', torch.from_numpy(intrinsics))
+
+
+        self.eps = 1e-5
+        #self.projector = Reprojector(img_w,img_h,M_proj)
+        #self.solver = PnPSolver(self.projector.M.numpy())
+        #self.loss_fn = torch.nn.MSELoss(reduce=False, size_average=False)
+        #self.loss_fn = torch.nn.MSELoss()
+        self.loss_fn = torch.nn.L1Loss()
+
+
+    def forward(self,verts3d, points2d, affine):
+        # verts3d:  B x N x 3 predicted model-space vertices
+        # points2d: B x N x 2 predicted 2D landmarks in the augmented crop
+        # affine:   B x 2 x 3 crop transforms, undone below before solving PnP
+        ones = torch.ones([points2d.shape[0] , points2d.shape[1], 1],device=points2d.device)
+        verts_homo = torch.cat((points2d, ones), 2)
+        K = eye_like(affine,3)
+        K[:,:2,:3] = affine
+        inv_k = torch.linalg.inv(K)
+        points2d_inv = inv_k@verts_homo.permute(0,2,1)
+        points2d_inv = points2d_inv.permute(0,2,1)[:,:,:2]
+
+        intrinsics = self.intrinsics.repeat([verts3d.shape[0],1,1 ])
+        #print(verts3d.double().shape)
+        #print(points2d.double().shape)
+        #print(intrinsics.shape)
+        RT_ = kornia.geometry.solve_pnp_dlt(verts3d.double(), points2d_inv.double(), intrinsics,svd_eps=self.eps)
+        RT_ = RT_.float()
+        RT = eye_like(verts3d,4)
+#       RT[:,1:3,:] *=-1
+        RT[:,:3,:] = RT_
+        RT = RT.permute(0,2,1)
+        RT[:,:,:2] *= -1
+
+        ones = torch.ones([verts3d.shape[0] , verts3d.shape[1], 1],device=verts3d.device)
+        verts_homo = torch.cat((verts3d, ones), 2)
+        M = self.M.repeat([verts3d.shape[0],1,1 ])
+        verts = verts_homo @ RT @ M
+        w_ = verts[:,:, [3]]
+        verts = verts / w_
+        reproject_points2d = verts[:,:, :2]
+        loss = self.loss_fn(reproject_points2d , points2d_inv)
+
+        return loss
diff --git a/reconstruction/jmlr/lr_scheduler.py b/reconstruction/jmlr/lr_scheduler.py
new file mode 100644
index 0000000..3e5461d
--- /dev/null
+++ b/reconstruction/jmlr/lr_scheduler.py
@@ -0,0 +1,93 @@
+import torch
+from torch.optim.lr_scheduler import _LRScheduler
+
+
+class PolyScheduler(_LRScheduler):
+    def __init__(self,
+                 optimizer,
+                 base_lr,
+                 max_steps,
+                 warmup_steps,
+                 last_epoch=-1):
+        self.base_lr = base_lr
+        self.warmup_lr_init = 0.0001
+        self.max_steps: int = max_steps
+        self.warmup_steps: int = warmup_steps
+        self.power = 2
+        super(PolyScheduler, self).__init__(optimizer, last_epoch, False)
+
+    def get_warmup_lr(self):
+        alpha = float(self.last_epoch) / float(self.warmup_steps)
+        #_lr = max(self.base_lr * alpha, self.warmup_lr_init)
+        _lr = self.base_lr * alpha
+        return [_lr for _ in self.optimizer.param_groups]
+
+    def get_lr(self):
+        if self.last_epoch == -1:
+            return [self.warmup_lr_init for _ in self.optimizer.param_groups]
+        if self.last_epoch < self.warmup_steps:
+            return self.get_warmup_lr()
+        else:
+            alpha = pow(
+                1 - float(self.last_epoch - self.warmup_steps) /
+                float(self.max_steps - self.warmup_steps),
+                self.power,
+            )
+            return [self.base_lr * alpha for _ in self.optimizer.param_groups]
+
+class StepScheduler(_LRScheduler):
+    def __init__(self,
+                 optimizer,
+                 base_lr,
+                 lr_steps,
+                 warmup_steps,
+                 last_epoch=-1):
+        self.base_lr = base_lr
+        self.warmup_lr_init = 0.0001
+        self.lr_steps = lr_steps
+        self.warmup_steps: int = warmup_steps
+        super(StepScheduler, self).__init__(optimizer, last_epoch, False)
+
+    def get_warmup_lr(self):
+        alpha = float(self.last_epoch) / float(self.warmup_steps)
+        #_lr = max(self.base_lr * alpha, self.warmup_lr_init)
+        _lr = self.base_lr * alpha
+        return [_lr for _ in self.optimizer.param_groups]
+
+    def get_lr(self):
+        if self.last_epoch == -1:
+            return [self.warmup_lr_init for _ in self.optimizer.param_groups]
+        if self.last_epoch < self.warmup_steps:
+            return self.get_warmup_lr()
+        else:
+            alpha = 0.1 ** len([m for m in self.lr_steps if m <= self.last_epoch])
+            return [self.base_lr * alpha for _ in self.optimizer.param_groups]
+
+
+
+def get_scheduler(opt, cfg):
+    if cfg.lr_func is not None:
+        scheduler =
torch.optim.lr_scheduler.LambdaLR( + optimizer=opt, lr_lambda=cfg.lr_func) + else: + #total_batch_size = cfg.batch_size * cfg.world_size + #warmup_steps = cfg.num_images // total_batch_size * cfg.warmup_epochs + #total_steps = cfg.num_images // total_batch_size * cfg.num_epochs + + if cfg.lr_steps is None: + scheduler = PolyScheduler( + optimizer=opt, + base_lr=cfg.lr, + max_steps=cfg.total_steps, + warmup_steps=cfg.warmup_steps, + ) + else: + scheduler = StepScheduler( + optimizer=opt, + base_lr=cfg.lr, + lr_steps=cfg.lr_steps, + warmup_steps=cfg.warmup_steps, + ) + + return scheduler + diff --git a/reconstruction/jmlr/rec_builder.py b/reconstruction/jmlr/rec_builder.py new file mode 100644 index 0000000..d914f12 --- /dev/null +++ b/reconstruction/jmlr/rec_builder.py @@ -0,0 +1,113 @@ +import pickle +import numpy as np +import os +import os.path as osp +import glob +import argparse +import cv2 +import time +import datetime +import pickle +import sklearn +import mxnet as mx +from utils.utils_config import get_config +from dataset import FaceDataset, Rt26dof + +class RecBuilder(): + def __init__(self, path, image_size=(112, 112), is_train=True): + self.path = path + self.image_size = image_size + self.widx = 0 + self.wlabel = 0 + self.max_label = -1 + #assert not osp.exists(path), '%s exists' % path + if is_train: + rec_file = osp.join(path, 'train.rec') + idx_file = osp.join(path, 'train.idx') + else: + rec_file = osp.join(path, 'val.rec') + idx_file = osp.join(path, 'val.idx') + #assert not osp.exists(rec_file), '%s exists' % rec_file + if not osp.exists(path): + os.makedirs(path) + self.writer = mx.recordio.MXIndexedRecordIO(idx_file, + rec_file, + 'w') + self.meta = [] + + def add(self, imgs): + #!!! img should be BGR!!!! + #assert label >= 0 + #assert label > self.last_label + assert len(imgs) > 0 + label = self.wlabel + for img in imgs: + idx = self.widx + image_meta = {'image_index': idx, 'image_classes': [label]} + header = mx.recordio.IRHeader(0, label, idx, 0) + if isinstance(img, np.ndarray): + s = mx.recordio.pack_img(header,img,quality=95,img_fmt='.jpg') + else: + s = mx.recordio.pack(header, img) + self.writer.write_idx(idx, s) + self.meta.append(image_meta) + self.widx += 1 + self.max_label = label + self.wlabel += 1 + return label + + + def add_image(self, img, label): + #!!! img should be BGR!!!! 
+        #assert label >= 0
+        #assert label > self.last_label
+        idx = self.widx
+        header = mx.recordio.IRHeader(0, label, idx, 0)
+        if isinstance(img, np.ndarray):
+            s = mx.recordio.pack_img(header,img,quality=100,img_fmt='.jpg')
+        else:
+            s = mx.recordio.pack(header, img)
+        self.writer.write_idx(idx, s)
+        self.widx += 1
+
+    def close(self):
+        print('stat:', self.widx, self.wlabel)
+
+if __name__ == "__main__":
+    cfg = get_config('configs/s1.py')
+    cfg.task = 0
+    cfg.input_size = 512
+    for is_train in [True, False]:
+        dataset = FaceDataset(cfg, is_train=is_train, local_rank=0)
+        dataset.transform = None
+        writer = RecBuilder(cfg.cache_dir, is_train=is_train)
+        #writer = RecBuilder("temp", is_train=is_train)
+        print('total:', len(dataset))
+        meta = np.zeros( (len(dataset), 3), dtype=np.float32 )
+        subset_name = 'train' if is_train else 'val'
+        meta_path = osp.join(cfg.cache_dir, '%s.meta'%subset_name)
+        for idx in range(len(dataset)):
+            #img_local, img_global, label_verts, label_Rt, tform = dataset[idx]
+            img_local, label_verts, label_Rt, tform = dataset[idx]
+            label_verts = label_verts.numpy()
+            label_Rt = label_Rt.numpy()
+            tform = tform.numpy()
+            label_6dof = Rt26dof(label_Rt, True)
+            pose = label_6dof[:3]
+            meta[idx] = pose
+            #print(image.shape, label_verts.shape, label_6dof.shape)
+            #print(image.__class__, label_verts.__class__)
+            img_local = img_local[:,:,::-1]
+            #img_global = img_global[:,:,::-1]
+            #image = np.concatenate( (img_local, img_global), axis=1 )
+            image = img_local
+            label = list(label_verts.flatten()) + list(label_Rt.flatten()) + list(tform.flatten())
+            assert len(label)==1220*3+16+6
+            writer.add_image(image, label)
+            if idx%100==0:
+                print('processing:', idx, image.shape, len(label))
+            if idx<10:
+                cv2.imwrite("temp/%d.jpg"%idx, image)
+        writer.close()
+        np.save(meta_path, meta)
+
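+# Editor's note: each record written above carries a flat float label of
+# 1220*3 vertex coordinates + 16 R_t entries + 6 affine entries (3682
+# values); MXFaceDataset.__getitem__ in dataset.py slices it back out with
+# exactly the same offsets.
+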
diff --git a/reconstruction/jmlr/train.py b/reconstruction/jmlr/train.py
new file mode 100644
index 0000000..3b77d82
--- /dev/null
+++ b/reconstruction/jmlr/train.py
@@ -0,0 +1,280 @@
+import argparse
+import logging
+import os
+import time
+import timm
+import glob
+import numpy as np
+import os.path as osp
+
+import torch
+import torch.distributed as dist
+from torch import nn
+import torch.nn.functional as F
+import torch.utils.data.distributed
+from torch.nn.utils import clip_grad_norm_
+from dataset import FaceDataset, DataLoaderX, MXFaceDataset, get_tris
+
+from backbones import get_network
+from utils.utils_amp import MaxClipGradScaler
+from utils.utils_callbacks import CallBackVerification, CallBackLogging, CallBackModelCheckpoint
+from utils.utils_logging import AverageMeter, init_logging
+from utils.utils_config import get_config
+from lr_scheduler import get_scheduler
+from timm.optim.optim_factory import create_optimizer
+
+
+def main(args):
+    cfg = get_config(args.config)
+    if not cfg.tf32:
+        torch.backends.cuda.matmul.allow_tf32 = False
+        torch.backends.cudnn.allow_tf32 = False
+    else:
+        torch.backends.cuda.matmul.allow_tf32 = True
+        torch.backends.cudnn.allow_tf32 = True
+
+    try:
+        world_size = int(os.environ['WORLD_SIZE'])
+        rank = int(os.environ['RANK'])
+        dist.init_process_group('nccl')
+    except KeyError:
+        world_size = 1
+        rank = 0
+        dist.init_process_group(backend='nccl', init_method="tcp://127.0.0.1:12584", rank=rank, world_size=world_size)
+
+    local_rank = args.local_rank
+    torch.cuda.set_device(local_rank)
+
+    if not os.path.exists(cfg.output) and rank == 0:
+        os.makedirs(cfg.output)
+    else:
+        time.sleep(2)
+
+    log_root = logging.getLogger()
+    init_logging(log_root, rank, cfg.output)
+    if rank==0:
+        logging.info(args)
+        logging.info(cfg)
+        print(cfg.flipindex.shape, cfg.flipindex[400:410])
+    train_set = MXFaceDataset(cfg=cfg, is_train=True, local_rank=local_rank)
+    cfg.num_images = len(train_set)
+    cfg.world_size = world_size
+    total_batch_size = cfg.batch_size * cfg.world_size
+    epoch_steps = cfg.num_images // total_batch_size
+    cfg.warmup_steps = epoch_steps * cfg.warmup_epochs
+    if cfg.max_warmup_steps>0:
+        cfg.warmup_steps = min(cfg.max_warmup_steps, cfg.warmup_steps)
+    cfg.total_steps = epoch_steps * cfg.num_epochs
+    if cfg.lr_epochs is not None:
+        cfg.lr_steps = [m*epoch_steps for m in cfg.lr_epochs]
+    else:
+        cfg.lr_steps = None
+    train_sampler = torch.utils.data.distributed.DistributedSampler(
+        train_set, shuffle=True)
+
+    train_loader = torch.utils.data.DataLoader(
+        dataset=train_set, batch_size=cfg.batch_size,
+        sampler=train_sampler, num_workers=4, pin_memory=False, drop_last=True)
+
+    net = get_network(cfg).to(local_rank)
+
+    if cfg.resume:
+        try:
+            ckpts = list(glob.glob(osp.join(cfg.resume_path, "backbone*.pth")))
+            backbone_pth = sorted(ckpts)[-1]
+            backbone_ckpt = torch.load(backbone_pth, map_location=torch.device(local_rank))
+            net.load_state_dict(backbone_ckpt['model'])
+            if rank==0:
+                logging.info("backbone resume successfully! %s"%backbone_pth)
+        except (FileNotFoundError, KeyError, IndexError, RuntimeError):
+            logging.info("resume fail!!")
+            raise RuntimeError
+
+    net = torch.nn.parallel.DistributedDataParallel(
+        module=net, broadcast_buffers=False, device_ids=[local_rank])
+    net.train()
+
+    if cfg.opt=='sgd':
+        opt = torch.optim.SGD(
+            params=[
+                {"params": net.parameters()},
+            ],
+            lr=cfg.lr, momentum=0.9, weight_decay=cfg.weight_decay)
+    elif cfg.opt=='adam':
+        opt = torch.optim.Adam(
+            params=[
+                {"params": net.parameters()},
+            ],
+            lr=cfg.lr)
+    elif cfg.opt=='adamw':
+        opt = torch.optim.AdamW(
+            params=[
+                {"params": net.parameters()},
+            ],
+            lr=cfg.lr, weight_decay=cfg.weight_decay)
+
+    scheduler = get_scheduler(opt, cfg)
+    if cfg.resume:
+        if rank==0:
+            logging.info(opt)
+        # after resuming, force every param group back to the configured LR
+        for g in opt.param_groups:
+            for key in ['lr', 'initial_lr']:
+                g[key] = cfg.lr
+
+    start_epoch = 0
+    total_step = cfg.total_steps
+    if rank==0:
+        logging.info(opt)
+        logging.info("Total Step is: %d" % total_step)
+
+    loss = {
+        'Loss': AverageMeter(),
+    }
+
+    global_step = 0
+    grad_amp = None
+    if cfg.fp16>0:
+        if cfg.fp16==1:
+            grad_amp = torch.cuda.amp.grad_scaler.GradScaler(growth_interval=100)
+        elif cfg.fp16==2:
+            grad_amp = MaxClipGradScaler(64, 1024, growth_interval=200)
+        elif cfg.fp16==3:
+            grad_amp = MaxClipGradScaler(4, 8, growth_interval=200)
+        else:
+            raise ValueError('fp16 mode not set')
+
+    callback_checkpoint = CallBackModelCheckpoint(rank, cfg)
+
+    callback_checkpoint(global_step, net, opt)
+
+    callback_logging = CallBackLogging(50, rank, total_step, cfg.batch_size, world_size, None)
+
+    l1loss = nn.L1Loss()
+
+    tris = get_tris(cfg)
+    tri_index = torch.tensor(tris, dtype=torch.long).to(local_rank)
+
+    for epoch in range(start_epoch, cfg.num_epochs):
+        train_sampler.set_epoch(epoch)
+        for step, value in enumerate(train_loader):
+            global_step += 1
+            img = value['img_local'].to(local_rank)
+            dloss = {}
+            assert cfg.task==0
+            label_verts = value['verts'].to(local_rank)
+            label_points2d = value['points2d'].to(local_rank)
+            preds = net(img)
+
+            pred_verts, pred_points2d = preds.split([1220*3, 1220*2], dim=1)
+            pred_verts = pred_verts.view(cfg.batch_size, 1220, 3)
+            pred_points2d =
pred_points2d.view(cfg.batch_size, 1220, 2) + if not cfg.use_rtloss: + loss1 = F.l1_loss(pred_verts, label_verts) + else: + label_Rt = value['rt'].to(local_rank) + _ones = torch.ones([pred_verts.shape[0], 1220, 1], device=pred_verts.device) + pred_verts = torch.cat([pred_verts/10, _ones], dim=2) + pred_verts = torch.bmm(pred_verts,label_Rt) * 10.0 + label_verts = torch.cat([label_verts/10, _ones], dim=2) + label_verts = torch.bmm(label_verts,label_Rt) * 10.0 + loss1 = F.l1_loss(pred_verts, label_verts) + + loss2 = F.l1_loss(pred_points2d, label_points2d) + loss3d = loss1 * cfg.lossw_verts3d + loss2d = loss2 * cfg.lossw_verts2d + dloss['Loss'] = loss3d + loss2d + dloss['Loss3D'] = loss3d + dloss['Loss2D'] = loss2d + + if cfg.loss_bone3d: + bone_losses = [] + for i in range(3): + pred_verts_x = pred_verts[:,tri_index[:,i%3],:] + pred_verts_y = pred_verts[:,tri_index[:,(i+1)%3],:] + label_verts_x = label_verts[:,tri_index[:,i%3],:] + label_verts_y = label_verts[:,tri_index[:,(i+1)%3],:] + dist_pred = torch.norm(pred_verts_x - pred_verts_y, p=2, dim=-1, keepdim=False) + dist_label = torch.norm(label_verts_x - label_verts_y, p=2, dim=-1, keepdim=False) + bone_losses.append(F.l1_loss(dist_pred, dist_label) * cfg.lossw_bone3d) + dloss['LossBone3d'] = sum(bone_losses) + + + if cfg.loss_bone2d: + bone_losses = [] + for i in range(3): + pred_points2d_x = pred_points2d[:,tri_index[:,i%3],:] + pred_points2d_y = pred_points2d[:,tri_index[:,(i+1)%3],:] + label_points2d_x = label_points2d[:,tri_index[:,i%3],:] + label_points2d_y = label_points2d[:,tri_index[:,(i+1)%3],:] + dist_pred = torch.norm(pred_points2d_x - pred_points2d_y, p=2, dim=-1, keepdim=False) + dist_label = torch.norm(label_points2d_x - label_points2d_y, p=2, dim=-1, keepdim=False) + bone_losses.append(F.l1_loss(dist_pred, dist_label) * cfg.lossw_bone2d) + dloss['LossBone2d'] = sum(bone_losses) + + iter_loss = dloss['Loss'] + + if cfg.fp16>0: + grad_amp.scale(iter_loss).backward() + grad_amp.unscale_(opt) + if cfg.fp16<2: + torch.nn.utils.clip_grad_norm_(net.parameters(), 5) + grad_amp.step(opt) + grad_amp.update() + else: + iter_loss.backward() + opt.step() + + + opt.zero_grad() + + if cfg.lr_func is None: + scheduler.step() + + with torch.no_grad(): + loss['Loss'].update(iter_loss.item(), 1) + for k in dloss: + if k=='Loss': + continue + v = dloss[k].item() + if k not in loss: + loss[k] = AverageMeter() + loss[k].update(v, 1) + + callback_logging(global_step, loss, epoch, cfg.fp16, grad_amp, opt) + + if cfg.lr_func is not None: + scheduler.step() + + callback_checkpoint(9999, net, opt) + dist.destroy_process_group() + + +if __name__ == "__main__": + torch.backends.cudnn.benchmark = True + parser = argparse.ArgumentParser(description='JMLR Training') + parser.add_argument('config', type=str, help='config file') + parser.add_argument('--local_rank', type=int, default=0, help='local_rank') + args_ = parser.parse_args() + main(args_) + diff --git a/reconstruction/jmlr/utils/__init__.py b/reconstruction/jmlr/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/reconstruction/jmlr/utils/plot.py b/reconstruction/jmlr/utils/plot.py new file mode 100644 index 0000000..ccc588e --- /dev/null +++ b/reconstruction/jmlr/utils/plot.py @@ -0,0 +1,72 @@ +# coding: utf-8 + +import os +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap +from prettytable import PrettyTable +from sklearn.metrics import roc_curve, 
auc
+
+image_path = "/data/anxiang/IJB_release/IJBC"
+files = [
+    "./ms1mv3_arcface_r100/ms1mv3_arcface_r100/ijbc.npy"
+]
+
+
+def read_template_pair_list(path):
+    pairs = pd.read_csv(path, sep=' ', header=None).values
+    t1 = pairs[:, 0].astype(int)
+    t2 = pairs[:, 1].astype(int)
+    label = pairs[:, 2].astype(int)
+    return t1, t2, label
+
+
+p1, p2, label = read_template_pair_list(
+    os.path.join('%s/meta' % image_path,
+                 '%s_template_pair_label.txt' % 'ijbc'))
+
+methods = []
+scores = []
+for file in files:
+    methods.append(file.split('/')[-2])
+    scores.append(np.load(file))
+
+methods = np.array(methods)
+scores = dict(zip(methods, scores))
+colours = dict(
+    zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2')))
+x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1]
+tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels])
+fig = plt.figure()
+for method in methods:
+    fpr, tpr, _ = roc_curve(label, scores[method])
+    roc_auc = auc(fpr, tpr)
+    fpr = np.flipud(fpr)
+    tpr = np.flipud(tpr)  # select largest tpr at same fpr
+    plt.plot(fpr,
+             tpr,
+             color=colours[method],
+             lw=1,
+             label=('[%s (AUC = %0.4f %%)]' %
+                    (method.split('-')[-1], roc_auc * 100)))
+    tpr_fpr_row = []
+    tpr_fpr_row.append("%s-%s" % (method, "IJBC"))
+    for fpr_iter in np.arange(len(x_labels)):
+        _, min_index = min(
+            list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr)))))
+        tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100))
+    tpr_fpr_table.add_row(tpr_fpr_row)
+plt.xlim([10 ** -6, 0.1])
+plt.ylim([0.3, 1.0])
+plt.grid(linestyle='--', linewidth=1)
+plt.xticks(x_labels)
+plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True))
+plt.xscale('log')
+plt.xlabel('False Positive Rate')
+plt.ylabel('True Positive Rate')
+plt.title('ROC on IJB')
+plt.legend(loc="lower right")
+print(tpr_fpr_table)
diff --git a/reconstruction/jmlr/utils/utils_amp.py b/reconstruction/jmlr/utils/utils_amp.py
new file mode 100644
index 0000000..1a2286f
--- /dev/null
+++ b/reconstruction/jmlr/utils/utils_amp.py
@@ -0,0 +1,82 @@
+from typing import Dict, List
+
+import torch
+#from torch._six import container_abcs
+import collections.abc as container_abcs
+from torch.cuda.amp import GradScaler
+
+
+class _MultiDeviceReplicator(object):
+    """
+    Lazily serves copies of a tensor to requested devices. Copies are cached per-device.
+    """
+
+    def __init__(self, master_tensor: torch.Tensor) -> None:
+        assert master_tensor.is_cuda
+        self.master = master_tensor
+        self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}
+
+    def get(self, device) -> torch.Tensor:
+        retval = self._per_device_tensors.get(device, None)
+        if retval is None:
+            retval = self.master.to(device=device, non_blocking=True, copy=True)
+            self._per_device_tensors[device] = retval
+        return retval
+
+
+class MaxClipGradScaler(GradScaler):
+    def __init__(self, init_scale, max_scale: float, growth_interval=100):
+        GradScaler.__init__(self, init_scale=init_scale, growth_interval=growth_interval)
+        self.max_scale = max_scale
+
+    def scale_clip(self):
+        if self.get_scale() == self.max_scale:
+            self.set_growth_factor(1)
+        elif self.get_scale() < self.max_scale:
+            self.set_growth_factor(2)
+        elif self.get_scale() > self.max_scale:
+            self._scale.fill_(self.max_scale)
+            self.set_growth_factor(1)
+
+    def scale(self, outputs):
+        """
+        Multiplies ('scales') a tensor or list of tensors by the scale factor.
+
+        Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned
+        unmodified.
+
+        Arguments:
+            outputs (Tensor or iterable of Tensors):  Outputs to scale.
+        """
+        if not self._enabled:
+            return outputs
+        self.scale_clip()
+        # Short-circuit for the common case.
+        if isinstance(outputs, torch.Tensor):
+            assert outputs.is_cuda
+            if self._scale is None:
+                self._lazy_init_scale_growth_tracker(outputs.device)
+            assert self._scale is not None
+            return outputs * self._scale.to(device=outputs.device, non_blocking=True)
+
+        # Invoke the more complex machinery only if we're treating multiple outputs.
+        stash: List[_MultiDeviceReplicator] = []  # holds a reference that can be overwritten by apply_scale
+
+        def apply_scale(val):
+            if isinstance(val, torch.Tensor):
+                assert val.is_cuda
+                if len(stash) == 0:
+                    if self._scale is None:
+                        self._lazy_init_scale_growth_tracker(val.device)
+                    assert self._scale is not None
+                    stash.append(_MultiDeviceReplicator(self._scale))
+                return val * stash[0].get(val.device)
+            elif isinstance(val, container_abcs.Iterable):
+                iterable = map(apply_scale, val)
+                if isinstance(val, list) or isinstance(val, tuple):
+                    return type(val)(iterable)
+                else:
+                    return iterable
+            else:
+                raise ValueError("outputs must be a Tensor or an iterable of Tensors")
+        return apply_scale(outputs)
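+
+# --- Illustrative sketch (editor's addition, not in the original source):
+# one AMP optimization step with MaxClipGradScaler; `model`, `opt`,
+# `loss_fn`, `x` and `y` are placeholders supplied by the caller.
+def _amp_step_sketch(model, opt, scaler, loss_fn, x, y):
+    with torch.cuda.amp.autocast():
+        loss = loss_fn(model(x), y)
+    scaler.scale(loss).backward()   # backward on the scaled loss
+    scaler.step(opt)                # unscales gradients, then opt.step()
+    scaler.update()                 # grow/clip the scale factor
+    opt.zero_grad()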
diff --git a/reconstruction/jmlr/utils/utils_callbacks.py b/reconstruction/jmlr/utils/utils_callbacks.py
new file mode 100644
index 0000000..ab740ec
--- /dev/null
+++ b/reconstruction/jmlr/utils/utils_callbacks.py
@@ -0,0 +1,129 @@
+import logging
+import os
+import time
+from typing import List
+
+import torch
+import psutil
+
+#from eval import verification
+#from partial_fc import PartialFC
+#from torch2onnx import convert_onnx
+from utils.utils_logging import AverageMeter
+
+
+class CallBackVerification(object):
+    # NOTE: ver_test depends on the `verification` module whose import is
+    # commented out above; train.py does not instantiate this callback.
+    def __init__(self, frequent, rank, val_targets, rec_prefix, image_size=(112, 112)):
+        self.frequent: int = frequent
+        self.rank: int = rank
+        self.highest_acc: float = 0.0
+        self.highest_acc_list: List[float] = [0.0] * len(val_targets)
+        self.ver_list: List[object] = []
+        self.ver_name_list: List[str] = []
+        if self.rank == 0:
+            self.init_dataset(val_targets=val_targets, data_dir=rec_prefix, image_size=image_size)
+
+    def ver_test(self, backbone: torch.nn.Module, global_step: int):
+        results = []
+        for i in range(len(self.ver_list)):
+            acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
+                self.ver_list[i], backbone, 10, 10)
+            logging.info('[%s][%d]XNorm: %f' % (self.ver_name_list[i], global_step, xnorm))
+            logging.info('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (self.ver_name_list[i], global_step, acc2, std2))
+            if acc2 > self.highest_acc_list[i]:
+                self.highest_acc_list[i] = acc2
+            logging.info(
+                '[%s][%d]Accuracy-Highest: %1.5f' % (self.ver_name_list[i], global_step, self.highest_acc_list[i]))
+            results.append(acc2)
+
+    def init_dataset(self, val_targets, data_dir, image_size):
+        for name in val_targets:
+            path = os.path.join(data_dir, name + ".bin")
+            if os.path.exists(path):
+                data_set = verification.load_bin(path, image_size)
+                self.ver_list.append(data_set)
+                self.ver_name_list.append(name)
+
+    def __call__(self, num_update, backbone: torch.nn.Module):
+        if self.rank == 0 and num_update > 0 and num_update % self.frequent == 0:
+            backbone.eval()
+            self.ver_test(backbone, num_update)
+            backbone.train()
+
+
+class CallBackLogging(object):
+    def __init__(self, frequent, rank, total_step, batch_size, world_size, writer=None):
+        self.frequent: int = frequent
+        self.rank: int = rank
+        self.time_start = time.time()
+        self.total_step: int = total_step
+        self.batch_size: int = batch_size
+        self.world_size: int = world_size
+        self.writer = writer
+
+        self.init = False
+        self.tic = 0
+
+    def __call__(self, global_step, loss, epoch, fp16, grad_scaler, opt):
+        if self.rank == 0 and global_step > 0 and global_step % self.frequent == 0:
+            if self.init:
+                try:
+                    speed: float = self.frequent * self.batch_size / (time.time() - self.tic)
+                    speed_total = speed * self.world_size
+                except ZeroDivisionError:
+                    speed_total = float('inf')
+
+                time_now = (time.time() - self.time_start) / 3600.0
+                time_total = time_now / ((global_step + 1) / self.total_step)
+                time_for_end = time_total - time_now
+                lr = opt.param_groups[0]['lr']
+                if self.writer is not None:
+                    self.writer.add_scalar('time_for_end', time_for_end, global_step)
+                    #self.writer.add_scalar('loss', loss.avg, global_step)
+                mem = psutil.virtual_memory()
+                mem_used = mem.used / (1024 ** 3)
+                loss_str = ""
+                for k,v in loss.items():
+                    if len(loss_str)!=0:
+                        loss_str += " "
+                    loss_str += "%s:%.4f"%(k, v.avg)
+                if fp16:
+                    msg = "Speed %.2f samples/sec %s Epoch: %d Global Step: %d LR: %.8f " \
+                          "Fp16 Grad Scale: %2.f Required: %.1f hours MemUsed: %.3f" % (
+                              speed_total, loss_str, epoch, global_step, lr, grad_scaler.get_scale(), time_for_end, mem_used
+                          )
+                else:
+                    msg = "Speed %.2f samples/sec %s Epoch: %d Global Step: %d LR: %.8f Required: %.1f hours MemUsed: %.3f" % (
+                        speed_total, loss_str, epoch, global_step, lr, time_for_end, mem_used
+                    )
+                logging.info(msg)
+                for k,v in loss.items():
+                    v.reset()
+                self.tic = time.time()
+            else:
+                self.init = True
+                self.tic = time.time()
+
+
+class CallBackModelCheckpoint(object):
+    def __init__(self, rank, cfg):
+        self.rank = rank
+        self.output = cfg.output
+        #self.save_pfc = cfg.save_pfc
+        #self.save_onnx = cfg.save_onnx
+        self.save_opt = cfg.save_opt
+
+    def __call__(self, epoch, backbone, opt_backbone):
+        if self.rank == 0:
+            path_module = os.path.join(self.output, "backbone_ep%04d.pth"%epoch)
+            if self.save_opt:
+                data = {
+                    'model': backbone.module.state_dict(),
+                    'optimizer': opt_backbone.state_dict(),
+                }
+            else:
+                data = backbone.module.state_dict()
+            torch.save(data, path_module)
+            logging.info("Pytorch Model Saved in '{}'".format(path_module))
+
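+# --- Illustrative sketch (editor's addition, not in the original source):
+# reloading a checkpoint written by CallBackModelCheckpoint with
+# save_opt=True; `net` and `opt` are placeholders for the same model and
+# optimizer classes used during training.
+def _load_checkpoint_sketch(path, net, opt):
+    data = torch.load(path, map_location='cpu')
+    net.load_state_dict(data['model'])
+    opt.load_state_dict(data['optimizer'])
+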
diff --git a/reconstruction/jmlr/utils/utils_config.py b/reconstruction/jmlr/utils/utils_config.py
new file mode 100644
index 0000000..74408b3
--- /dev/null
+++ b/reconstruction/jmlr/utils/utils_config.py
@@ -0,0 +1,27 @@
+import importlib
+import os
+import os.path as osp
+import numpy as np
+
+def get_config(config_file):
+    assert config_file.startswith('configs/'), 'config file setting must start with configs/'
+    temp_config_name = osp.basename(config_file)
+    temp_module_name = osp.splitext(temp_config_name)[0]
+    # start from configs/base.py and overlay the job config on top of it
+    config1 = importlib.import_module("configs.base")
+    importlib.reload(config1)
+    cfg = config1.config
+    config2 = importlib.import_module("configs.%s"%temp_module_name)
+    importlib.reload(config2)
+    job_cfg = config2.config
+    cfg.update(job_cfg)
+    cfg.job_name = temp_module_name
+    if cfg.output is None:
+        cfg.output = osp.join('work_dirs', temp_module_name)
+    cfg.flipindex = np.load(cfg.flipindex_file)
+    return cfg
diff --git a/reconstruction/jmlr/utils/utils_logging.py b/reconstruction/jmlr/utils/utils_logging.py
new file mode 100644
index 0000000..7d58012
--- /dev/null
+++ b/reconstruction/jmlr/utils/utils_logging.py
@@ -0,0 +1,40 @@
+import logging
+import os
+import sys
+
+
+class AverageMeter(object):
+    """Computes and stores the average and current value
+    """
+
+    def __init__(self):
+        self.val = None
+        self.avg = None
+        self.sum = None
+        self.count = None
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+
+def init_logging(log_root, rank, models_root):
+    if rank == 0:
+        log_root.setLevel(logging.INFO)
+        formatter = logging.Formatter("Training: %(asctime)s-%(message)s")
+        handler_file = logging.FileHandler(os.path.join(models_root, "training.log"))
+        handler_stream = logging.StreamHandler(sys.stdout)
+        handler_file.setFormatter(formatter)
+        handler_stream.setFormatter(formatter)
+        log_root.addHandler(handler_file)
+        log_root.addHandler(handler_stream)
+        log_root.info('rank_id: %d' % rank)
diff --git a/reconstruction/jmlr/utils/utils_os.py b/reconstruction/jmlr/utils/utils_os.py
new file mode 100644
index 0000000..e69de29