# Copyright (c) Alibaba, Inc. and its affiliates.

import contextlib
import hashlib
import json
import os
import pickle
import random
import re
import shutil
import tempfile
from collections import OrderedDict
from collections.abc import Mapping
from pathlib import Path
from types import FunctionType
from typing import Any, Dict, Union

import numpy as np
import torch
import torch.optim
from torch import nn

from .test_utils import compare_arguments_nested


class RegressTool:
    """This class is used to detect unintended changes to inference/training results in unittests.

    First, run a baseline test to create a baseline result file; afterwards, changes can be observed
    by comparing the latest version against the baseline file.
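
    A rough sketch of the intended workflow (hypothetical unittest names; the REGRESSION_BASELINE
    environment variable must be set for monitoring to be active):

    >>> tool = RegressTool(baseline=True)   # first run: record a baseline file
    >>> with tool.monitor_module_single_forward(model, 'my_model_regression'):
    >>>     model(**inputs)
    >>> # later runs: construct with baseline=False to compare against the stored baseline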
    """

    def __init__(self,
                 baseline: bool = None,
                 store_func: FunctionType = None,
                 load_func: FunctionType = None):
        """A func to store the baseline file and a func to load the baseline file.
        """
        self.baseline = baseline
        self.store_func = store_func
        self.load_func = load_func
        print(f'Current working dir is: {Path.cwd()}')

    def store(self, local, remote):
        if self.store_func is not None:
            self.store_func(local, remote)
        else:
            path = os.path.abspath(
                os.path.join(Path.cwd(), 'data', 'test', 'regression'))
            os.makedirs(path, exist_ok=True)
            shutil.copy(local, os.path.join(path, remote))

    def load(self, local, remote):
        if self.load_func is not None:
            self.load_func(local, remote)
        else:
            path = os.path.abspath(
                os.path.join(Path.cwd(), 'data', 'test', 'regression'))
            baseline = os.path.join(path, remote)
            if not os.path.exists(baseline):
                raise ValueError(f'baseline file {baseline} does not exist')
            print(
                f'local file found:{baseline}, md5:{hashlib.md5(open(baseline,"rb").read()).hexdigest()}'
            )
            if os.path.exists(local):
                os.remove(local)
            os.symlink(baseline, local, target_is_directory=False)

    @contextlib.contextmanager
    def monitor_module_single_forward(self,
                                      module: nn.Module,
                                      file_name: str,
                                      compare_fn=None,
                                      compare_model_output=True,
                                      **kwargs):
        """Monitor a pytorch module in a single forward.

        Args:
            module: A torch module
            file_name: The file_name to store or load the baseline file
            compare_fn: A custom fn used to compare the results manually.
            compare_model_output: Only compare the input module's output, skip all other tensors

            >>> def compare_fn(v1, v2, key, type):
            >>>     return None

            v1 is the baseline value
            v2 is the value of the current version
            key is the key of submodules
            type is one of 'input', 'output'

        kwargs:
            atol: The absolute tolerance between two np arrays.
            rtol: The relative tolerance between two np arrays.
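
        For example, a compare_fn that relaxes the check for one particular submodule might look like
        this (a sketch; the submodule name is hypothetical):

        >>> def compare_fn(v1, v2, key, type):
        >>>     if key == 'backbone.layer1' and type == 'output':
        >>>         return True   # accept this submodule's output regardless of drift
        >>>     return None       # fall back to the default comparison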
        """
        baseline = os.getenv('REGRESSION_BASELINE')
        if baseline is None or self.baseline is None:
            yield
            return

        baseline = self.baseline
        io_json = {}
        absolute_path = f'./{file_name}.bin'
        if not isinstance(module, nn.Module):
            assert hasattr(module, 'model')
            module = module.model

        hack_forward(module, file_name, io_json)
        intercept_module(module, io_json)
        yield
        hack_forward(module, None, None, restore=True)
        intercept_module(module, None, restore=True)
        if baseline:
            with open(absolute_path, 'wb') as f:
                pickle.dump(io_json, f)
            self.store(absolute_path, f'{file_name}.bin')
            os.remove(absolute_path)
        else:
            name = os.path.basename(absolute_path)
            baseline = os.path.join(tempfile.gettempdir(), name)
            self.load(baseline, name)
            with open(baseline, 'rb') as f:
                base = pickle.load(f)

            class SafeNumpyEncoder(json.JSONEncoder):

                def parse_default(self, obj):
                    if isinstance(obj, np.ndarray):
                        return obj.tolist()

                    if isinstance(obj, np.floating):
                        return float(obj)

                    if isinstance(obj, np.integer):
                        return int(obj)

                    return json.JSONEncoder.default(self, obj)

                def default(self, obj):
                    try:
                        # Delegate to parse_default; calling self.default here would recurse forever.
                        return self.parse_default(obj)
                    except Exception:
                        print(
                            f'Type {obj.__class__} cannot be serialized and printed'
                        )
                        return None

            if compare_model_output:
                print(
                    'Ignore inner modules, only the output of the model will be verified.'
                )
                base = {
                    key: value
                    for key, value in base.items() if key == file_name
                }
                for key, value in base.items():
                    value['input'] = {'args': None, 'kwargs': None}
                io_json = {
                    key: value
                    for key, value in io_json.items() if key == file_name
                }
                for key, value in io_json.items():
                    value['input'] = {'args': None, 'kwargs': None}

            print(f'baseline: {json.dumps(base, cls=SafeNumpyEncoder)}')
            print(f'latest : {json.dumps(io_json, cls=SafeNumpyEncoder)}')
            if not compare_io_and_print(base, io_json, compare_fn, **kwargs):
                raise ValueError('Result not match!')

    @contextlib.contextmanager
    def monitor_module_train(self,
                             trainer: Union[Dict, Any],
                             file_name,
                             level='config',
                             compare_fn=None,
                             ignore_keys=None,
                             compare_random=True,
                             reset_dropout=True,
                             lazy_stop_callback=None,
                             **kwargs):
        """Monitor a pytorch module's backward data and cfg data within a step of the optimizer.

        This is usually useful when you try to change some dangerous code
        which has the risk of affecting the training loop.

        Args:
            trainer: A dict or an object that contains the model/optimizer/lr_scheduler
            file_name: The file_name to store or load the baseline file
            level: The regression level.
                'strict' for matching every single tensor.
                    Please make sure the parameters of the head are fixed
                    and the drop-out rate is zero.
                'config' for matching the initial config, like the cfg file, optimizer param_groups,
                    lr_scheduler params and the random seed.
                'metric' for comparing the best metrics in the evaluation loop.
            compare_fn: A custom fn used to compare the results manually.
            ignore_keys: The keys of the named_parameters to ignore.
            compare_random: Whether to compare random settings, default True.
            reset_dropout: Reset all dropout modules to 0.0.
            lazy_stop_callback: A callback passed in; when the monitoring is over, this callback will be called.
        kwargs:
            atol: The absolute tolerance between two np arrays.
            rtol: The relative tolerance between two np arrays.

            >>> def compare_fn(v1, v2, key, type):
            >>>     return None

            v1 is the baseline value
            v2 is the value of the current version
            key is the key of modules/parameters
            type is one of 'input', 'output', 'backward', 'optimizer', 'lr_scheduler', 'cfg', 'state'
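
        A minimal usage sketch (a hedged example, not the only entry point; assumes REGRESSION_BASELINE
        is set and `trainer` is a dict carrying `model`, `optimizer` and `lr_scheduler`; the forward
        call and loss key are placeholders):

        >>> tool = RegressTool(baseline=False)
        >>> with tool.monitor_module_train(trainer, 'my_train_regression', level='config'):
        >>>     loss = trainer['model'](**inputs)['loss']
        >>>     loss.backward()
        >>>     trainer['optimizer'].step()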
        """
        baseline = os.getenv('REGRESSION_BASELINE')
        if baseline is None or self.baseline is None:
            yield
            return

        baseline = self.baseline

        io_json = {}
        bw_json = {}
        absolute_path = f'./{file_name}.bin'

        if level == 'strict':
            print(
                "[Important] The level of regression is 'strict', please make sure your model's parameters are "
                'fixed and all drop-out rates have been set to zero.')

        assert hasattr(
            trainer, 'model') or 'model' in trainer, 'model must be in trainer'
        module = trainer['model'] if isinstance(trainer,
                                                dict) else trainer.model
        if not isinstance(module, nn.Module):
            assert hasattr(module, 'model')
            module = module.model

        assert hasattr(
            trainer, 'optimizer'
        ) or 'optimizer' in trainer, 'optimizer must be in trainer'
        assert hasattr(
            trainer, 'lr_scheduler'
        ) or 'lr_scheduler' in trainer, 'lr_scheduler must be in trainer'
        optimizer: torch.optim.Optimizer = trainer['optimizer'] if isinstance(
            trainer, dict) else trainer.optimizer
        lr_scheduler: torch.optim.lr_scheduler._LRScheduler = trainer['lr_scheduler'] if isinstance(trainer, dict) \
            else trainer.lr_scheduler
        torch_state = numpify_tensor_nested(torch.get_rng_state())
        np_state = np.random.get_state()
        random_seed = random.getstate()
        seed = trainer._seed if hasattr(
            trainer,
            '_seed') else trainer.seed if hasattr(trainer, 'seed') else None

        if reset_dropout:
            with torch.no_grad():

                def reinit_dropout(_module):
                    for name, submodule in _module.named_children():
                        if isinstance(submodule, torch.nn.Dropout):
                            setattr(_module, name, torch.nn.Dropout(0.))
                        else:
                            reinit_dropout(submodule)

                reinit_dropout(module)

        if level == 'strict':
            hack_forward(module, file_name, io_json)
            intercept_module(module, io_json)
        hack_backward(module,
                      optimizer,
                      bw_json,
                      lazy_stop_callback=lazy_stop_callback)
        yield
        hack_backward(module, optimizer, None, restore=True)
        if level == 'strict':
            hack_forward(module, None, None, restore=True)
            intercept_module(module, None, restore=True)

        optimizer_dict = optimizer.state_dict()
        optimizer_dict.pop('state', None)
        summary = {
            'forward': io_json,
            'backward': bw_json,
            'optimizer': {
                'type': optimizer.__class__.__name__,
                'defaults': optimizer.defaults,
                'state_dict': optimizer_dict
            },
            'lr_scheduler': {
                'type': lr_scheduler.__class__.__name__,
                'state_dict': lr_scheduler.state_dict()
            },
            'cfg': trainer.cfg.to_dict() if hasattr(trainer, 'cfg') else None,
            'state': {
                'torch_state': torch_state,
                'np_state': np_state,
                'random_seed': random_seed,
                'seed': seed,
            }
        }

        if baseline:
            with open(absolute_path, 'wb') as f:
                pickle.dump(summary, f)
            self.store(absolute_path, f'{file_name}.bin')
            os.remove(absolute_path)
        else:
            name = os.path.basename(absolute_path)
            baseline = os.path.join(tempfile.gettempdir(), name)
            self.load(baseline, name)
            with open(baseline, 'rb') as f:
                baseline_json = pickle.load(f)

            if level == 'strict' and not compare_io_and_print(
                    baseline_json['forward'], io_json, compare_fn, **kwargs):
                raise RuntimeError('Forward not match!')
            if not compare_backward_and_print(baseline_json['backward'],
                                              bw_json,
                                              compare_fn=compare_fn,
                                              ignore_keys=ignore_keys,
                                              level=level,
                                              **kwargs):
                raise RuntimeError('Backward not match!')
            cfg_opt1 = {
                'optimizer': baseline_json['optimizer'],
                'lr_scheduler': baseline_json['lr_scheduler'],
                'cfg': baseline_json['cfg'],
                'state': None if not compare_random else baseline_json['state']
            }
            cfg_opt2 = {
                'optimizer': summary['optimizer'],
                'lr_scheduler': summary['lr_scheduler'],
                'cfg': summary['cfg'],
                'state': None if not compare_random else summary['state']
            }
            if not compare_cfg_and_optimizers(cfg_opt1, cfg_opt2, compare_fn,
                                              **kwargs):
                raise RuntimeError('Cfg or optimizers not match!')


class MsRegressTool(RegressTool):
    class EarlyStopError(Exception):
        pass

    @contextlib.contextmanager
    def monitor_ms_train(self,
                         trainer,
                         file_name,
                         level='config',
                         compare_fn=None,
                         ignore_keys=None,
                         compare_random=True,
                         lazy_stop_callback=None,
                         **kwargs):
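        """Monitor a ModelScope-style trainer by wrapping its `train_loop` with `monitor_module_train`.

        A rough usage sketch (assumes REGRESSION_BASELINE is set and `trainer` exposes `train_loop`,
        `register_hook`, `model`, `optimizer` and `lr_scheduler`; the names below are placeholders):

        >>> tool = MsRegressTool(baseline=False)
        >>> with tool.monitor_ms_train(trainer, 'my_trainer_regression', level='config'):
        >>>     trainer.train()
        """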
        if lazy_stop_callback is None:

            def lazy_stop_callback():

                class EarlyStopHook:
                    PRIORITY = 90

                    def before_run(self, trainer):
                        pass

                    def after_run(self, trainer):
                        pass

                    def before_epoch(self, trainer):
                        pass

                    def after_epoch(self, trainer):
                        pass

                    def before_iter(self, trainer):
                        pass

                    def before_train_epoch(self, trainer):
                        self.before_epoch(trainer)

                    def before_val_epoch(self, trainer):
                        self.before_epoch(trainer)

                    def after_train_epoch(self, trainer):
                        self.after_epoch(trainer)

                    def after_val_epoch(self, trainer):
                        self.after_epoch(trainer)

                    def before_train_iter(self, trainer):
                        self.before_iter(trainer)

                    def before_val_iter(self, trainer):
                        self.before_iter(trainer)

                    def after_train_iter(self, trainer):
                        self.after_iter(trainer)

                    def after_val_iter(self, trainer):
                        self.after_iter(trainer)

                    def every_n_epochs(self, trainer, n):
                        return (trainer.epoch + 1) % n == 0 if n > 0 else False

                    def every_n_inner_iters(self, runner, n):
                        return (runner.inner_iter +
                                1) % n == 0 if n > 0 else False

                    def every_n_iters(self, trainer, n):
                        return (trainer.iter + 1) % n == 0 if n > 0 else False

                    def end_of_epoch(self, trainer):
                        return trainer.inner_iter + 1 == trainer.iters_per_epoch

                    def is_last_epoch(self, trainer):
                        return trainer.epoch + 1 == trainer.max_epochs

                    def is_last_iter(self, trainer):
                        return trainer.iter + 1 == trainer.max_iters

                    def get_triggered_stages(self):
                        return []

                    def state_dict(self):
                        return {}

                    def load_state_dict(self, state_dict):
                        pass

                    def after_iter(self, trainer):
                        raise MsRegressTool.EarlyStopError('Test finished.')

                trainer.register_hook(EarlyStopHook())

        def _train_loop(trainer, *args_train, **kwargs_train):
            with self.monitor_module_train(
                    trainer,
                    file_name,
                    level,
                    compare_fn=compare_fn,
                    ignore_keys=ignore_keys,
                    compare_random=compare_random,
                    lazy_stop_callback=lazy_stop_callback,
                    **kwargs):
                try:
                    return trainer.train_loop_origin(*args_train,
                                                     **kwargs_train)
                except MsRegressTool.EarlyStopError:
                    pass

        trainer.train_loop_origin, trainer.train_loop = \
            trainer.train_loop, type(trainer.train_loop)(_train_loop, trainer)
        yield


def compare_module(module1: nn.Module, module2: nn.Module):
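    """Return True if the two modules have element-wise identical parameters, False otherwise."""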
    for p1, p2 in zip(module1.parameters(), module2.parameters()):
        if p1.data.ne(p2.data).sum() > 0:
            return False
    return True


def numpify_tensor_nested(tensors, reduction=None, clip_value=10000):
    """Numpify `tensors` (even if it's a nested list/tuple of tensors)."""
    try:
        from modelscope.outputs import ModelOutputBase
    except ImportError:
        ModelOutputBase = dict
    if isinstance(tensors, (Mapping, ModelOutputBase)):
        return OrderedDict({
            k: numpify_tensor_nested(t, reduction, clip_value)
            for k, t in tensors.items()
        })
    if isinstance(tensors, list):
        return list(
            numpify_tensor_nested(t, reduction, clip_value) for t in tensors)
    if isinstance(tensors, tuple):
        return tuple(
            numpify_tensor_nested(t, reduction, clip_value) for t in tensors)
    if isinstance(tensors, torch.Tensor):
        t: np.ndarray = tensors.cpu().numpy()
        if clip_value is not None:
            t = np.where(t > clip_value, clip_value, t)
            t = np.where(t < -clip_value, -clip_value, t)
        # np.float has been removed from recent NumPy releases; use np.float64 instead.
        if reduction == 'sum':
            return t.sum(dtype=np.float64)
        elif reduction == 'mean':
            return t.mean(dtype=np.float64)
        return t
    return tensors


def detach_tensor_nested(tensors):
    """Detach `tensors` (even if it's a nested list/tuple of tensors)."""
    try:
        from modelscope.outputs import ModelOutputBase
    except ImportError:
        ModelOutputBase = dict
    if isinstance(tensors, (Mapping, ModelOutputBase)):
        return OrderedDict(
            {k: detach_tensor_nested(t)
             for k, t in tensors.items()})
    if isinstance(tensors, list):
        return list(detach_tensor_nested(t) for t in tensors)
    if isinstance(tensors, tuple):
        return tuple(detach_tensor_nested(t) for t in tensors)
    if isinstance(tensors, torch.Tensor):
        return tensors.detach()
    return tensors


def hack_forward(module: nn.Module,
                 name,
                 io_json,
                 restore=False,
                 keep_tensors=False):
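    """Monkey-patch `module.forward` so that each call records its inputs and outputs into `io_json[name]`.

    By default only 'sum'/'mean' reductions of the tensors are stored; `keep_tensors=True` keeps the full
    numpified tensors. Calling again with `restore=True` restores the original forward.
    """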
    def _forward(self, *args, **kwargs):
        ret = self.forward_origin(*args, **kwargs)
        if keep_tensors:
            args = numpify_tensor_nested(detach_tensor_nested(args))
            kwargs = numpify_tensor_nested(detach_tensor_nested(kwargs))
            output = numpify_tensor_nested(detach_tensor_nested(ret))
        else:
            args = {
                'sum':
                numpify_tensor_nested(detach_tensor_nested(args),
                                      reduction='sum'),
                'mean':
                numpify_tensor_nested(detach_tensor_nested(args),
                                      reduction='mean'),
            }
            kwargs = {
                'sum':
                numpify_tensor_nested(detach_tensor_nested(kwargs),
                                      reduction='sum'),
                'mean':
                numpify_tensor_nested(detach_tensor_nested(kwargs),
                                      reduction='mean'),
            }
            output = {
                'sum':
                numpify_tensor_nested(detach_tensor_nested(ret),
                                      reduction='sum'),
                'mean':
                numpify_tensor_nested(detach_tensor_nested(ret),
                                      reduction='mean'),
            }

        io_json[name] = {
            'input': {
                'args': args,
                'kwargs': kwargs,
            },
            'output': output,
        }
        return ret

    if not restore and not hasattr(module, 'forward_origin'):
        module.forward_origin, module.forward = module.forward, type(
            module.forward)(_forward, module)
    if restore and hasattr(module, 'forward_origin'):
        module.forward = module.forward_origin
        del module.forward_origin


def hack_backward(module: nn.Module,
                  optimizer,
                  io_json,
                  restore=False,
                  lazy_stop_callback=None):
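    """Monkey-patch `optimizer.step` to record each parameter's data and grad before the step and its data
    after the step into `io_json`, optionally firing `lazy_stop_callback`; `restore=True` restores the
    original step.
    """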
    def _step(self, *args, **kwargs):
        for name, param in module.named_parameters():
            io_json[name] = {
                'data': {
                    'sum':
                    numpify_tensor_nested(detach_tensor_nested(param.data),
                                          reduction='sum'),
                    'mean':
                    numpify_tensor_nested(detach_tensor_nested(param.data),
                                          reduction='mean'),
                },
                'grad': {
                    'sum':
                    numpify_tensor_nested(detach_tensor_nested(param.grad),
                                          reduction='sum'),
                    'mean':
                    numpify_tensor_nested(detach_tensor_nested(param.grad),
                                          reduction='mean'),
                }
            }
        ret = self.step_origin(*args, **kwargs)
        for name, param in module.named_parameters():
            io_json[name]['data_after'] = {
                'sum':
                numpify_tensor_nested(detach_tensor_nested(param.data),
                                      reduction='sum'),
                'mean':
                numpify_tensor_nested(detach_tensor_nested(param.data),
                                      reduction='mean'),
            }
        if lazy_stop_callback is not None:
            lazy_stop_callback()
        return ret

    if not restore and not hasattr(optimizer, 'step_origin'):
        optimizer.step_origin, optimizer.step = optimizer.step, type(
            optimizer.state_dict)(_step, optimizer)
    if restore and hasattr(optimizer, 'step_origin'):
        optimizer.step = optimizer.step_origin
        del optimizer.step_origin


def intercept_module(module: nn.Module,
                     io_json,
                     parent_name=None,
                     restore=False):
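    """Recursively apply `hack_forward` to every named child of `module`, keyed by its dotted submodule path."""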
    for name, submodule in module.named_children():
        full_name = parent_name + '.' + name if parent_name is not None else name
        hack_forward(submodule, full_name, io_json, restore)
        intercept_module(submodule, io_json, full_name, restore)


def compare_io_and_print(baseline_json, io_json, compare_fn=None, **kwargs):
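    """Compare the recorded forward inputs/outputs of each module against the baseline.

    Prints any unmatched keys and returns True only when all shared keys match.
    """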
    if compare_fn is None:

        def compare_fn(*args, **kwargs):
            return None

    keys1 = set(baseline_json.keys())
    keys2 = set(io_json.keys())
    added = keys1 - keys2
    removed = keys2 - keys1
    print(f'unmatched keys: {added}, {removed}')
    shared_keys = keys1.intersection(keys2)
    match = True
    for key in shared_keys:
        v1 = baseline_json[key]
        v2 = io_json[key]

        v1input = numpify_tensor_nested(v1['input'])
        v2input = numpify_tensor_nested(v2['input'])
        res = compare_fn(v1input, v2input, key, 'input')
        if res is not None:
            print(
                f'input of {key} compared with user compare_fn with result:{res}\n'
            )
            match = match and res
        else:
            match = compare_arguments_nested(
                f'unmatched module {key} input args', v1input['args'],
                v2input['args'], **kwargs) and match
            match = compare_arguments_nested(
                f'unmatched module {key} input kwargs', v1input['kwargs'],
                v2input['kwargs'], **kwargs) and match
        v1output = numpify_tensor_nested(v1['output'])
        v2output = numpify_tensor_nested(v2['output'])
        res = compare_fn(v1output, v2output, key, 'output')
        if res is not None:
            print(
                f'output of {key} compared with user compare_fn with result:{res}\n'
            )
            match = match and res
        else:
            match = compare_arguments_nested(f'unmatched module {key} outputs',
                                             arg1=v1output,
                                             arg2=v2output,
                                             **kwargs) and match
    return match


def compare_backward_and_print(baseline_json,
                               bw_json,
                               level,
                               ignore_keys=None,
                               compare_fn=None,
                               **kwargs):
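    """Compare the parameter data/grad recorded around an optimizer step against the baseline.

    Grad and post-step data are only checked when level == 'strict'; returns True when everything matches.
    """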
    if compare_fn is None:

        def compare_fn(*args, **kwargs):
            return None

    keys1 = set(baseline_json.keys())
    keys2 = set(bw_json.keys())
    added = keys1 - keys2
    removed = keys2 - keys1
    print(f'unmatched backward keys: {added}, {removed}')
    shared_keys = keys1.intersection(keys2)
    match = True
    for key in shared_keys:
        if ignore_keys is not None and key in ignore_keys:
            continue

        res = compare_fn(baseline_json[key], bw_json[key], key, 'backward')
        if res is not None:
            print(f'backward data of {key} compared with '
                  f'user compare_fn with result:{res}\n')
            match = match and res
        else:
            data1, grad1, data_after1 = baseline_json[key][
                'data'], baseline_json[key]['grad'], baseline_json[key][
                    'data_after']
            data2, grad2, data_after2 = bw_json[key]['data'], bw_json[key][
                'grad'], bw_json[key]['data_after']
            match = compare_arguments_nested(
                f'unmatched module {key} tensor data',
                arg1=data1,
                arg2=data2,
                **kwargs) and match
            if level == 'strict':
                match = compare_arguments_nested(
                    f'unmatched module {key} grad data',
                    arg1=grad1,
                    arg2=grad2,
                    **kwargs) and match
                match = compare_arguments_nested(
                    f'unmatched module {key} data after step', data_after1,
                    data_after2, **kwargs) and match
    return match


def compare_cfg_and_optimizers(baseline_json,
                               cfg_json,
                               compare_fn=None,
                               **kwargs):
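    """Compare optimizer/lr_scheduler settings, the cfg dict and the captured random state against the baseline."""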
    if compare_fn is None:

        def compare_fn(*args, **kwargs):
            return None

    optimizer1, lr_scheduler1, cfg1, state1 = baseline_json[
        'optimizer'], baseline_json['lr_scheduler'], baseline_json[
            'cfg'], baseline_json['state']
    # state2 must come from cfg_json, not the baseline, for the random-state comparison to be meaningful.
    optimizer2, lr_scheduler2, cfg2, state2 = cfg_json['optimizer'], cfg_json[
        'lr_scheduler'], cfg_json['cfg'], cfg_json['state']

    match = True
    res = compare_fn(optimizer1, optimizer2, None, 'optimizer')
    if res is not None:
        print(f'optimizer compared with user compare_fn with result:{res}\n')
        match = match and res
    else:
        if optimizer1['type'] != optimizer2['type']:
            print(
                f"Optimizer type not equal:{optimizer1['type']} and {optimizer2['type']}"
            )
        match = compare_arguments_nested(
            'unmatched optimizer defaults', optimizer1['defaults'],
            optimizer2['defaults'], **kwargs) and match
        match = compare_arguments_nested(
            'unmatched optimizer state_dict', optimizer1['state_dict'],
            optimizer2['state_dict'], **kwargs) and match

    res = compare_fn(lr_scheduler1, lr_scheduler2, None, 'lr_scheduler')
    if res is not None:
        print(
            f'lr_scheduler compared with user compare_fn with result:{res}\n')
        match = match and res
    else:
        if lr_scheduler1['type'] != lr_scheduler2['type']:
            print(
                f"Lr_scheduler type not equal:{lr_scheduler1['type']} and {lr_scheduler2['type']}"
            )
        match = compare_arguments_nested(
            'unmatched lr_scheduler state_dict', lr_scheduler1['state_dict'],
            lr_scheduler2['state_dict'], **kwargs) and match

    res = compare_fn(cfg1, cfg2, None, 'cfg')
    if res is not None:
        print(f'cfg compared with user compare_fn with result:{res}\n')
        match = match and res
    else:
        match = compare_arguments_nested(
            'unmatched cfg', arg1=cfg1, arg2=cfg2, **kwargs) and match

    res = compare_fn(state1, state2, None, 'state')
    if res is not None:
        print(
            f'random state compared with user compare_fn with result:{res}\n')
        match = match and res
    else:
        match = compare_arguments_nested('unmatched random state', state1,
                                         state2, **kwargs) and match

    return match


class IgnoreKeyFn:
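    """A ready-made compare_fn that accepts (i.e. skips strict comparison for) keys fully matching any of
    the given regex patterns.

    A minimal usage sketch (hypothetical `tool`, `model` and `inputs` from a unittest):

    >>> with tool.monitor_module_single_forward(
    >>>         model, 'my_model_regression', compare_fn=IgnoreKeyFn(['.*head.*'])):
    >>>     model(**inputs)
    """
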
    def __init__(self, keys):
        if isinstance(keys, str):
            keys = [keys]
        self.keys = keys if isinstance(keys, list) else []

    def __call__(self, v1output, v2output, key, type):
        for _key in self.keys:
            pattern = re.compile(_key)
            if key is not None and pattern.fullmatch(key):
                return True
        return None