Files
insightface/parsing/dml_csr/utils/schp.py
QINGPING ZHENG fb988ae0ca Create schp.py
2022-03-23 00:21:32 +08:00

88 lines
2.7 KiB
Python

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
"""
@Author : Qingping Zheng
@Contact : qingpingzheng2014@gmail.com
@File : schp.py
@Time : 10/01/21 00:00 PM
@Desc :
@License : Licensed under the Apache License, Version 2.0 (the "License");
@Copyright : Copyright 2022 The Authors. All Rights Reserved.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import shutil

import torch
def moving_average(net1, net2, alpha=1):
    """Blend the parameters of ``net2`` into ``net1`` in place.

    Each parameter of ``net1`` becomes ``(1 - alpha) * p1 + alpha * p2``.
    With the default ``alpha=1`` this copies ``net2``'s parameter values
    into ``net1``.

    Parameters are paired positionally via ``zip`` over ``parameters()``.
    Any surplus parameters in the longer network are ignored.

    Args:
        net1: module whose parameters are overwritten.
        net2: module whose parameters are blended in (only read).
        alpha: weight given to ``net2``'s parameters.
    """
    keep = 1.0 - alpha
    for dst, src in zip(net1.parameters(), net2.parameters()):
        # Work on ``.data`` so the update is not recorded by autograd.
        dst.data.mul_(keep)
        dst.data.add_(src.data * alpha)
def _check_bn(module, flag):
    """Set ``flag[0] = True`` if ``module`` looks like a batch-norm layer.

    A module counts as batch-norm when its class name contains
    ``'BatchNorm'`` or ``'InPlaceABNSync'``. The flag is never reset to
    False, so it can be shared across ``model.apply`` calls.

    Args:
        module: module to inspect.
        flag: one-element mutable list used as an out-parameter.
    """
    kind = type(module).__name__
    looks_like_bn = 'BatchNorm' in kind or 'InPlaceABNSync' in kind
    if looks_like_bn:
        flag[0] = True
def check_bn(model):
    """Return True if any module reached by ``model.apply`` is batch-norm-like.

    A module is batch-norm-like when its class name contains ``'BatchNorm'``
    or ``'InPlaceABNSync'``; the test itself is done by ``_check_bn``.
    """
    found = [False]

    def _visit(submodule):
        _check_bn(submodule, found)

    model.apply(_visit)
    has_bn = found[0]
    return has_bn
def reset_bn(module):
    """Reset the running statistics of a batch-norm-like module.

    Intended for use as ``model.apply(reset_bn)``. For modules whose class
    name contains ``'BatchNorm'`` or ``'InPlaceABNSync'``:

    - ``running_mean`` is replaced by zeros;
    - ``running_var`` is replaced by ones.

    All other modules are left untouched.

    A BN layer built with ``track_running_stats=False`` stores ``None`` for
    these buffers. Such statistics (and missing attributes) are skipped
    instead of raising ``TypeError`` from ``torch.zeros_like(None)``.
    """
    classname = module.__class__.__name__
    if classname.find('BatchNorm') == -1 and classname.find('InPlaceABNSync') == -1:
        return
    running_mean = getattr(module, 'running_mean', None)
    if running_mean is not None:
        module.running_mean = torch.zeros_like(running_mean)
    running_var = getattr(module, 'running_var', None)
    if running_var is not None:
        module.running_var = torch.ones_like(running_var)
def _get_momenta(module, momenta):
    """Record the ``momentum`` of a batch-norm-like module.

    If ``module``'s class name contains ``'BatchNorm'`` or
    ``'InPlaceABNSync'``, store ``momenta[module] = module.momentum``.
    Other modules are ignored.

    Args:
        module: module to inspect.
        momenta: dict that receives the module -> momentum mapping.
    """
    kind = type(module).__name__
    if not ('BatchNorm' in kind or 'InPlaceABNSync' in kind):
        return
    momenta[module] = module.momentum
def _set_momenta(module, momenta):
    """Restore a batch-norm-like module's ``momentum`` from ``momenta``.

    Counterpart of ``_get_momenta``. For modules whose class name contains
    ``'BatchNorm'`` or ``'InPlaceABNSync'``, assign
    ``module.momentum = momenta[module]``; a missing entry raises
    ``KeyError``. Other modules are ignored.
    """
    kind = type(module).__name__
    is_bn_like = 'BatchNorm' in kind or 'InPlaceABNSync' in kind
    if is_bn_like:
        saved = momenta[module]
        module.momentum = saved
def bn_re_estimate(loader, model):
    """Re-estimate batch-norm running statistics with one pass over ``loader``.

    How it works:

    1. Put the model in train mode and reset every batch-norm-like layer's
       running mean/var.
    2. Run each batch through the model with per-batch momentum
       ``b / (n + b)``. Here ``b`` is the batch size and ``n`` is the number
       of samples seen so far.
    3. After the pass, the running statistics equal the sample-weighted
       average of the per-batch statistics.

    The layers' original momenta are restored afterwards, even if a forward
    pass raises. The model is left in train mode.

    Args:
        loader: iterable of batches. Only ``batch[0]`` is used and fed to
            ``model``; it is assumed to be a batch-first tensor (TODO confirm
            against callers).
        model: module to update. If no batch-norm-like layer is found, a
            message is printed and nothing is changed.
    """
    if not check_bn(model):
        print('No batch norm layer detected')
        return
    model.train()
    momenta = {}
    model.apply(reset_bn)
    model.apply(lambda module: _get_momenta(module, momenta))
    n = 0
    try:
        # Only the running-stat side effect of the forward pass is needed, so
        # disable autograd to avoid keeping activations for a backward pass.
        with torch.no_grad():
            for batch in loader:
                images = batch[0]
                b = images.size(0)
                if b == 0:
                    # An empty first batch would make the momentum 0 / 0.
                    continue
                momentum = b / (n + b)
                for module in momenta:
                    module.momentum = momentum
                model(images)
                n += b
    finally:
        model.apply(lambda module: _set_momenta(module, momenta))
def save_schp_checkpoint(states, is_best_parsing, output_dir, filename='schp_checkpoint.pth.tar'):
    """Save a checkpoint and optionally mirror it as the best parsing model.

    Args:
        states: object passed to ``torch.save``. It must support
            ``'state_dict' in states`` (e.g. a dict).
        is_best_parsing: if truthy and ``states`` contains ``'state_dict'``,
            the checkpoint is also written to
            ``output_dir/model_parsing_best.pth.tar``.
        output_dir: existing directory to write into.
        filename: file name of the regular checkpoint inside ``output_dir``.
    """
    save_path = os.path.join(output_dir, filename)
    # Write to a temp file, then rename over the target. An interrupted save
    # then never leaves a truncated file in place of the previous checkpoint.
    tmp_path = save_path + '.tmp'
    torch.save(states, tmp_path)
    os.replace(tmp_path, save_path)
    if is_best_parsing and 'state_dict' in states:
        best_save_path = os.path.join(output_dir, 'model_parsing_best.pth.tar')
        # Copy the bytes just written rather than serializing ``states`` again.
        # The previous best file is only replaced once the copy is complete.
        best_tmp_path = best_save_path + '.tmp'
        shutil.copyfile(save_path, best_tmp_path)
        os.replace(best_tmp_path, best_save_path)