# Copyright (c) Alibaba, Inc. and its affiliates.
import contextlib
import hashlib
import json
import os
import pickle
import random
import re
import shutil
import tempfile
from collections import OrderedDict
from collections.abc import Mapping
from pathlib import Path
from types import FunctionType
from typing import Any, Dict, Union
import numpy as np
import torch
import torch.optim
from torch import nn
from .test_utils import compare_arguments_nested
class RegressTool:
"""This class is used to stop inference/training results from changing by some unaware affections by unittests.
Firstly, run a baseline test to create a result file, then changes can be observed between
the latest version and the baseline file.
"""
def __init__(self,
baseline: bool = None,
store_func: FunctionType = None,
load_func: FunctionType = None):
"""A func to store the baseline file and a func to load the baseline file.
"""
self.baseline = baseline
self.store_func = store_func
self.load_func = load_func
print(f'Current working dir is: {Path.cwd()}')
def store(self, local, remote):
if self.store_func is not None:
self.store_func(local, remote)
else:
path = os.path.abspath(
os.path.join(Path.cwd(), 'data', 'test', 'regression'))
os.makedirs(path, exist_ok=True)
shutil.copy(local, os.path.join(path, remote))
def load(self, local, remote):
if self.load_func is not None:
self.load_func(local, remote)
else:
path = os.path.abspath(
os.path.join(Path.cwd(), 'data', 'test', 'regression'))
baseline = os.path.join(path, remote)
if not os.path.exists(baseline):
                raise ValueError(f'baseline file {baseline} does not exist')
print(
f'local file found:{baseline}, md5:{hashlib.md5(open(baseline,"rb").read()).hexdigest()}'
)
if os.path.exists(local):
os.remove(local)
os.symlink(baseline, local, target_is_directory=False)
@contextlib.contextmanager
def monitor_module_single_forward(self,
module: nn.Module,
file_name: str,
compare_fn=None,
compare_model_output=True,
**kwargs):
"""Monitor a pytorch module in a single forward.
Args:
module: A torch module
file_name: The file_name to store or load file
compare_fn: A custom fn used to compare the results manually.
compare_model_output: Only compare the input module's output, skip all other tensors
>>> def compare_fn(v1, v2, key, type):
>>> return None
v1 is the baseline value
v2 is the value of current version
key is the key of submodules
type is in one of 'input', 'output'
kwargs:
atol: The absolute gap between two np arrays.
rtol: The relative gap between two np arrays.
"""
baseline = os.getenv('REGRESSION_BASELINE')
if baseline is None or self.baseline is None:
yield
return
baseline = self.baseline
io_json = {}
absolute_path = f'./{file_name}.bin'
if not isinstance(module, nn.Module):
assert hasattr(module, 'model')
module = module.model
hack_forward(module, file_name, io_json)
intercept_module(module, io_json)
yield
hack_forward(module, None, None, restore=True)
intercept_module(module, None, restore=True)
if baseline:
with open(absolute_path, 'wb') as f:
pickle.dump(io_json, f)
self.store(absolute_path, f'{file_name}.bin')
os.remove(absolute_path)
else:
name = os.path.basename(absolute_path)
baseline = os.path.join(tempfile.gettempdir(), name)
self.load(baseline, name)
with open(baseline, 'rb') as f:
base = pickle.load(f)
class SafeNumpyEncoder(json.JSONEncoder):
def parse_default(self, obj):
if isinstance(obj, np.ndarray):
return obj.tolist()
if isinstance(obj, np.floating):
return float(obj)
if isinstance(obj, np.integer):
return int(obj)
return json.JSONEncoder.default(self, obj)
def default(self, obj):
try:
                    return self.parse_default(obj)
except Exception:
print(
f'Type {obj.__class__} cannot be serialized and printed'
)
return None
if compare_model_output:
            print(
                'Ignoring inner modules; only the output of the model will be verified.'
            )
base = {
key: value
for key, value in base.items() if key == file_name
}
for key, value in base.items():
value['input'] = {'args': None, 'kwargs': None}
io_json = {
key: value
for key, value in io_json.items() if key == file_name
}
for key, value in io_json.items():
value['input'] = {'args': None, 'kwargs': None}
print(f'baseline: {json.dumps(base, cls=SafeNumpyEncoder)}')
print(f'latest : {json.dumps(io_json, cls=SafeNumpyEncoder)}')
if not compare_io_and_print(base, io_json, compare_fn, **kwargs):
raise ValueError('Result not match!')
@contextlib.contextmanager
def monitor_module_train(self,
trainer: Union[Dict, Any],
file_name,
level='config',
compare_fn=None,
ignore_keys=None,
compare_random=True,
reset_dropout=True,
lazy_stop_callback=None,
**kwargs):
"""Monitor a pytorch module's backward data and cfg data within a step of the optimizer.
This is usually useful when you try to change some dangerous code
which has the risk of affecting the training loop.
Args:
trainer: A dict or an object contains the model/optimizer/lr_scheduler
file_name: The file_name to store or load file
level: The regression level.
'strict' for matching every single tensor.
Please make sure the parameters of head are fixed
and the drop-out rate is zero.
'config' for matching the initial config, like cfg file, optimizer param_groups,
lr_scheduler params and the random seed.
                'metric' for comparing the best metrics in the evaluation loop.
            compare_fn: A custom fn used to compare the results manually.

                >>> def compare_fn(v1, v2, key, type):
                >>>     return None

                v1 is the baseline value
                v2 is the value of the current version
                key is the key of the module/parameter
                type is one of 'input', 'output', 'backward', 'optimizer',
                'lr_scheduler', 'cfg', 'state'
            ignore_keys: The keys of the named_parameters to ignore.
            compare_random: Whether to compare random settings, default True.
            reset_dropout: Reset the rate of all dropout modules to 0.0.
            lazy_stop_callback: A callback that will be called once monitoring is over.
            kwargs:
                atol: The absolute tolerance between two np arrays.
                rtol: The relative tolerance between two np arrays.
        """
baseline = os.getenv('REGRESSION_BASELINE')
if baseline is None or self.baseline is None:
yield
return
baseline = self.baseline
io_json = {}
bw_json = {}
absolute_path = f'./{file_name}.bin'
if level == 'strict':
print(
"[Important] The level of regression is 'strict', please make sure your model's parameters are "
'fixed and all drop-out rates have been set to zero.')
assert hasattr(
trainer, 'model') or 'model' in trainer, 'model must be in trainer'
module = trainer['model'] if isinstance(trainer,
dict) else trainer.model
if not isinstance(module, nn.Module):
assert hasattr(module, 'model')
module = module.model
assert hasattr(
trainer, 'optimizer'
) or 'optimizer' in trainer, 'optimizer must be in trainer'
assert hasattr(
trainer, 'lr_scheduler'
) or 'lr_scheduler' in trainer, 'lr_scheduler must be in trainer'
optimizer: torch.optim.Optimizer = trainer['optimizer'] if isinstance(
trainer, dict) else trainer.optimizer
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = trainer['lr_scheduler'] if isinstance(trainer, dict) \
else trainer.lr_scheduler
torch_state = numpify_tensor_nested(torch.get_rng_state())
np_state = np.random.get_state()
random_seed = random.getstate()
        seed = getattr(trainer, '_seed', getattr(trainer, 'seed', None))
if reset_dropout:
with torch.no_grad():
def reinit_dropout(_module):
for name, submodule in _module.named_children():
if isinstance(submodule, torch.nn.Dropout):
setattr(_module, name, torch.nn.Dropout(0.))
else:
reinit_dropout(submodule)
reinit_dropout(module)
if level == 'strict':
hack_forward(module, file_name, io_json)
intercept_module(module, io_json)
hack_backward(module,
optimizer,
bw_json,
lazy_stop_callback=lazy_stop_callback)
yield
hack_backward(module, optimizer, None, restore=True)
if level == 'strict':
hack_forward(module, None, None, restore=True)
intercept_module(module, None, restore=True)
optimizer_dict = optimizer.state_dict()
optimizer_dict.pop('state', None)
summary = {
'forward': io_json,
'backward': bw_json,
'optimizer': {
'type': optimizer.__class__.__name__,
'defaults': optimizer.defaults,
'state_dict': optimizer_dict
},
'lr_scheduler': {
'type': lr_scheduler.__class__.__name__,
'state_dict': lr_scheduler.state_dict()
},
'cfg': trainer.cfg.to_dict() if hasattr(trainer, 'cfg') else None,
'state': {
'torch_state': torch_state,
'np_state': np_state,
'random_seed': random_seed,
'seed': seed,
}
}
if baseline:
with open(absolute_path, 'wb') as f:
pickle.dump(summary, f)
self.store(absolute_path, f'{file_name}.bin')
os.remove(absolute_path)
else:
name = os.path.basename(absolute_path)
baseline = os.path.join(tempfile.gettempdir(), name)
self.load(baseline, name)
with open(baseline, 'rb') as f:
baseline_json = pickle.load(f)
if level == 'strict' and not compare_io_and_print(
baseline_json['forward'], io_json, compare_fn, **kwargs):
raise RuntimeError('Forward not match!')
if not compare_backward_and_print(baseline_json['backward'],
bw_json,
compare_fn=compare_fn,
ignore_keys=ignore_keys,
level=level,
**kwargs):
raise RuntimeError('Backward not match!')
cfg_opt1 = {
'optimizer': baseline_json['optimizer'],
'lr_scheduler': baseline_json['lr_scheduler'],
'cfg': baseline_json['cfg'],
'state': None if not compare_random else baseline_json['state']
}
cfg_opt2 = {
'optimizer': summary['optimizer'],
'lr_scheduler': summary['lr_scheduler'],
'cfg': summary['cfg'],
'state': None if not compare_random else summary['state']
}
if not compare_cfg_and_optimizers(cfg_opt1, cfg_opt2, compare_fn,
**kwargs):
raise RuntimeError('Cfg or optimizers not match!')
class MsRegressTool(RegressTool):
class EarlyStopError(Exception):
pass
@contextlib.contextmanager
def monitor_ms_train(self,
trainer,
file_name,
level='config',
compare_fn=None,
ignore_keys=None,
compare_random=True,
lazy_stop_callback=None,
**kwargs):
if lazy_stop_callback is None:
def lazy_stop_callback():
class EarlyStopHook:
PRIORITY = 90
def before_run(self, trainer):
pass
def after_run(self, trainer):
pass
def before_epoch(self, trainer):
pass
def after_epoch(self, trainer):
pass
def before_iter(self, trainer):
pass
def before_train_epoch(self, trainer):
self.before_epoch(trainer)
def before_val_epoch(self, trainer):
self.before_epoch(trainer)
def after_train_epoch(self, trainer):
self.after_epoch(trainer)
def after_val_epoch(self, trainer):
self.after_epoch(trainer)
def before_train_iter(self, trainer):
self.before_iter(trainer)
def before_val_iter(self, trainer):
self.before_iter(trainer)
def after_train_iter(self, trainer):
self.after_iter(trainer)
def after_val_iter(self, trainer):
self.after_iter(trainer)
def every_n_epochs(self, trainer, n):
return (trainer.epoch + 1) % n == 0 if n > 0 else False
def every_n_inner_iters(self, runner, n):
return (runner.inner_iter +
1) % n == 0 if n > 0 else False
def every_n_iters(self, trainer, n):
return (trainer.iter + 1) % n == 0 if n > 0 else False
def end_of_epoch(self, trainer):
return trainer.inner_iter + 1 == trainer.iters_per_epoch
def is_last_epoch(self, trainer):
return trainer.epoch + 1 == trainer.max_epochs
def is_last_iter(self, trainer):
return trainer.iter + 1 == trainer.max_iters
def get_triggered_stages(self):
return []
def state_dict(self):
return {}
def load_state_dict(self, state_dict):
pass
def after_iter(self, trainer):
raise MsRegressTool.EarlyStopError('Test finished.')
trainer.register_hook(EarlyStopHook())
def _train_loop(trainer, *args_train, **kwargs_train):
with self.monitor_module_train(
trainer,
file_name,
level,
compare_fn=compare_fn,
ignore_keys=ignore_keys,
compare_random=compare_random,
lazy_stop_callback=lazy_stop_callback,
**kwargs):
try:
return trainer.train_loop_origin(*args_train,
**kwargs_train)
except MsRegressTool.EarlyStopError:
pass
trainer.train_loop_origin, trainer.train_loop = \
trainer.train_loop, type(trainer.train_loop)(_train_loop, trainer)
yield
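
# A hedged usage sketch for MsRegressTool (assumes a ModelScope trainer that
# exposes `train_loop` and `register_hook`; the file name is hypothetical):
#
#   tool = MsRegressTool(baseline=False)
#   with tool.monitor_ms_train(trainer, 'my_trainer_regression', level='config'):
#       trainer.train()  # EarlyStopHook ends the loop after the first iter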
def compare_module(module1: nn.Module, module2: nn.Module):
for p1, p2 in zip(module1.parameters(), module2.parameters()):
if p1.data.ne(p2.data).sum() > 0:
return False
return True
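
# A quick sketch of compare_module (names are hypothetical): it returns False
# as soon as any pair of corresponding parameters differs.
#
#   if not compare_module(model_before, model_after):
#       print('parameters changed')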
def numpify_tensor_nested(tensors, reduction=None, clip_value=10000):
    """Numpify `tensors` (even if it's a nested list/tuple of tensors)."""
    try:
        from modelscope.outputs import ModelOutputBase
    except ImportError:
        ModelOutputBase = dict
if isinstance(tensors, (Mapping, ModelOutputBase)):
return OrderedDict({
k: numpify_tensor_nested(t, reduction, clip_value)
for k, t in tensors.items()
})
if isinstance(tensors, list):
return list(
numpify_tensor_nested(t, reduction, clip_value) for t in tensors)
if isinstance(tensors, tuple):
return tuple(
numpify_tensor_nested(t, reduction, clip_value) for t in tensors)
if isinstance(tensors, torch.Tensor):
t: np.ndarray = tensors.cpu().numpy()
if clip_value is not None:
t = np.where(t > clip_value, clip_value, t)
t = np.where(t < -clip_value, -clip_value, t)
        if reduction == 'sum':
            return t.sum(dtype=np.float64)
        elif reduction == 'mean':
            return t.mean(dtype=np.float64)
return t
return tensors
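
# A small illustration of numpify_tensor_nested (values follow the code above):
#
#   numpify_tensor_nested({'a': torch.ones(2, 2)}, reduction='sum')
#   # -> OrderedDict([('a', 4.0)]); values are clipped to [-clip_value, clip_value]
#   #    before the reduction is applied.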
def detach_tensor_nested(tensors):
    """Detach `tensors` (even if it's a nested list/tuple of tensors)."""
    try:
        from modelscope.outputs import ModelOutputBase
    except ImportError:
        ModelOutputBase = dict
if isinstance(tensors, (Mapping, ModelOutputBase)):
return OrderedDict(
{k: detach_tensor_nested(t)
for k, t in tensors.items()})
if isinstance(tensors, list):
return list(detach_tensor_nested(t) for t in tensors)
if isinstance(tensors, tuple):
return tuple(detach_tensor_nested(t) for t in tensors)
if isinstance(tensors, torch.Tensor):
return tensors.detach()
return tensors
def hack_forward(module: nn.Module,
name,
io_json,
restore=False,
keep_tensors=False):
def _forward(self, *args, **kwargs):
ret = self.forward_origin(*args, **kwargs)
if keep_tensors:
args = numpify_tensor_nested(detach_tensor_nested(args))
kwargs = numpify_tensor_nested(detach_tensor_nested(kwargs))
output = numpify_tensor_nested(detach_tensor_nested(ret))
else:
args = {
'sum':
numpify_tensor_nested(detach_tensor_nested(args),
reduction='sum'),
'mean':
numpify_tensor_nested(detach_tensor_nested(args),
reduction='mean'),
}
kwargs = {
'sum':
numpify_tensor_nested(detach_tensor_nested(kwargs),
reduction='sum'),
'mean':
numpify_tensor_nested(detach_tensor_nested(kwargs),
reduction='mean'),
}
output = {
'sum':
numpify_tensor_nested(detach_tensor_nested(ret),
reduction='sum'),
'mean':
numpify_tensor_nested(detach_tensor_nested(ret),
reduction='mean'),
}
io_json[name] = {
'input': {
'args': args,
'kwargs': kwargs,
},
'output': output,
}
return ret
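
    # Note: `type(module.forward)` is types.MethodType for a bound method, so
    # `type(module.forward)(_forward, module)` binds `_forward` to this module
    # instance while the original forward stays reachable as `forward_origin`.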
if not restore and not hasattr(module, 'forward_origin'):
module.forward_origin, module.forward = module.forward, type(
module.forward)(_forward, module)
if restore and hasattr(module, 'forward_origin'):
module.forward = module.forward_origin
del module.forward_origin
def hack_backward(module: nn.Module,
optimizer,
io_json,
restore=False,
lazy_stop_callback=None):
def _step(self, *args, **kwargs):
for name, param in module.named_parameters():
io_json[name] = {
'data': {
'sum':
numpify_tensor_nested(detach_tensor_nested(param.data),
reduction='sum'),
'mean':
numpify_tensor_nested(detach_tensor_nested(param.data),
reduction='mean'),
},
'grad': {
'sum':
numpify_tensor_nested(detach_tensor_nested(param.grad),
reduction='sum'),
'mean':
numpify_tensor_nested(detach_tensor_nested(param.grad),
reduction='mean'),
}
}
ret = self.step_origin(*args, **kwargs)
for name, param in module.named_parameters():
io_json[name]['data_after'] = {
'sum':
numpify_tensor_nested(detach_tensor_nested(param.data),
reduction='sum'),
'mean':
numpify_tensor_nested(detach_tensor_nested(param.data),
reduction='mean'),
}
if lazy_stop_callback is not None:
lazy_stop_callback()
return ret
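
    # Same binding idiom as in hack_forward: `type(optimizer.state_dict)` is
    # types.MethodType (the type of any bound method), used here to attach
    # `_step` to the optimizer instance and keep the original as `step_origin`.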
if not restore and not hasattr(optimizer, 'step_origin'):
optimizer.step_origin, optimizer.step = optimizer.step, type(
optimizer.state_dict)(_step, optimizer)
if restore and hasattr(optimizer, 'step_origin'):
optimizer.step = optimizer.step_origin
del optimizer.step_origin
def intercept_module(module: nn.Module,
                     io_json,
                     parent_name=None,
                     restore=False):
    # Recursively hook (or restore) the forward of every named child module, so
    # the IO of each submodule is recorded under its full dotted name.
    for name, child in module.named_children():
        full_name = parent_name + '.' + name if parent_name is not None else name
        hack_forward(child, full_name, io_json, restore)
        intercept_module(child, io_json, full_name, restore)
def compare_io_and_print(baseline_json, io_json, compare_fn=None, **kwargs):
if compare_fn is None:
def compare_fn(*args, **kwargs):
return None
keys1 = set(baseline_json.keys())
keys2 = set(io_json.keys())
    missing = keys1 - keys2
    extra = keys2 - keys1
    print(f'unmatched keys - missing from current: {missing}, newly added: {extra}')
shared_keys = keys1.intersection(keys2)
match = True
for key in shared_keys:
v1 = baseline_json[key]
v2 = io_json[key]
v1input = numpify_tensor_nested(v1['input'])
v2input = numpify_tensor_nested(v2['input'])
res = compare_fn(v1input, v2input, key, 'input')
if res is not None:
print(
f'input of {key} compared with user compare_fn with result:{res}\n'
)
match = match and res
else:
match = compare_arguments_nested(
f'unmatched module {key} input args', v1input['args'],
v2input['args'], **kwargs) and match
match = compare_arguments_nested(
f'unmatched module {key} input kwargs', v1input['kwargs'],
v2input['kwargs'], **kwargs) and match
v1output = numpify_tensor_nested(v1['output'])
v2output = numpify_tensor_nested(v2['output'])
res = compare_fn(v1output, v2output, key, 'output')
if res is not None:
print(
f'output of {key} compared with user compare_fn with result:{res}\n'
)
match = match and res
else:
match = compare_arguments_nested(f'unmatched module {key} outputs',
arg1=v1output,
arg2=v2output,
**kwargs) and match
return match
def compare_backward_and_print(baseline_json,
bw_json,
level,
ignore_keys=None,
compare_fn=None,
**kwargs):
if compare_fn is None:
def compare_fn(*args, **kwargs):
return None
keys1 = set(baseline_json.keys())
keys2 = set(bw_json.keys())
    missing = keys1 - keys2
    extra = keys2 - keys1
    print(
        f'unmatched backward keys - missing from current: {missing}, newly added: {extra}'
    )
shared_keys = keys1.intersection(keys2)
match = True
for key in shared_keys:
if ignore_keys is not None and key in ignore_keys:
continue
res = compare_fn(baseline_json[key], bw_json[key], key, 'backward')
if res is not None:
print(f'backward data of {key} compared with '
f'user compare_fn with result:{res}\n')
match = match and res
else:
data1, grad1, data_after1 = baseline_json[key][
'data'], baseline_json[key]['grad'], baseline_json[key][
'data_after']
data2, grad2, data_after2 = bw_json[key]['data'], bw_json[key][
'grad'], bw_json[key]['data_after']
match = compare_arguments_nested(
f'unmatched module {key} tensor data',
arg1=data1,
arg2=data2,
**kwargs) and match
if level == 'strict':
match = compare_arguments_nested(
f'unmatched module {key} grad data',
arg1=grad1,
arg2=grad2,
**kwargs) and match
match = compare_arguments_nested(
f'unmatched module {key} data after step', data_after1,
data_after2, **kwargs) and match
return match
def compare_cfg_and_optimizers(baseline_json,
cfg_json,
compare_fn=None,
**kwargs):
if compare_fn is None:
def compare_fn(*args, **kwargs):
return None
    optimizer1, lr_scheduler1, cfg1, state1 = baseline_json[
        'optimizer'], baseline_json['lr_scheduler'], baseline_json[
            'cfg'], baseline_json['state']
    optimizer2, lr_scheduler2, cfg2, state2 = cfg_json['optimizer'], cfg_json[
        'lr_scheduler'], cfg_json['cfg'], cfg_json['state']
match = True
res = compare_fn(optimizer1, optimizer2, None, 'optimizer')
if res is not None:
print(f'optimizer compared with user compare_fn with result:{res}\n')
match = match and res
else:
if optimizer1['type'] != optimizer2['type']:
print(
f"Optimizer type not equal:{optimizer1['type']} and {optimizer2['type']}"
)
match = compare_arguments_nested(
'unmatched optimizer defaults', optimizer1['defaults'],
optimizer2['defaults'], **kwargs) and match
match = compare_arguments_nested(
'unmatched optimizer state_dict', optimizer1['state_dict'],
optimizer2['state_dict'], **kwargs) and match
res = compare_fn(lr_scheduler1, lr_scheduler2, None, 'lr_scheduler')
if res is not None:
print(
f'lr_scheduler compared with user compare_fn with result:{res}\n')
match = match and res
else:
        if lr_scheduler1['type'] != lr_scheduler2['type']:
            print(
                f"Lr_scheduler type not equal:{lr_scheduler1['type']} and {lr_scheduler2['type']}"
            )
match = compare_arguments_nested(
'unmatched lr_scheduler state_dict', lr_scheduler1['state_dict'],
lr_scheduler2['state_dict'], **kwargs) and match
res = compare_fn(cfg1, cfg2, None, 'cfg')
if res is not None:
print(f'cfg compared with user compare_fn with result:{res}\n')
match = match and res
else:
match = compare_arguments_nested(
'unmatched cfg', arg1=cfg1, arg2=cfg2, **kwargs) and match
res = compare_fn(state1, state2, None, 'state')
if res is not None:
print(
f'random state compared with user compare_fn with result:{res}\n')
match = match and res
else:
match = compare_arguments_nested('unmatched random state', state1,
state2, **kwargs) and match
return match
class IgnoreKeyFn:
    """A compare_fn that treats keys matching any of the given regex patterns as equal."""

    def __init__(self, keys):
        if isinstance(keys, str):
            keys = [keys]
        self.keys = keys if isinstance(keys, list) else []
def __call__(self, v1output, v2output, key, type):
for _key in self.keys:
pattern = re.compile(_key)
if key is not None and pattern.fullmatch(key):
return True
return None
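
# A hedged usage sketch for IgnoreKeyFn (the regex is hypothetical): keys that
# fullmatch a pattern return True and are treated as matching, while all other
# keys fall back to the default comparison.
#
#   tool.monitor_module_single_forward(
#       model, 'my_model_fw', compare_fn=IgnoreKeyFn(r'.*embedding.*'))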