mirror of
https://gitee.com/270580156/weiyu.git
synced 2026-05-19 21:57:49 +00:00
Sync from bytedesk-private: update
This commit is contained in:
0
modules/python/vendors/FunASR/funasr/train_utils/__init__.py
vendored
Normal file
0
modules/python/vendors/FunASR/funasr/train_utils/__init__.py
vendored
Normal file
31
modules/python/vendors/FunASR/funasr/train_utils/add_gradient_noise.py
vendored
Normal file
31
modules/python/vendors/FunASR/funasr/train_utils/add_gradient_noise.py
vendored
Normal file
@@ -0,0 +1,31 @@
|
||||
import torch
|
||||
|
||||
|
||||
def add_gradient_noise(
|
||||
model: torch.nn.Module,
|
||||
iteration: int,
|
||||
duration: float = 100,
|
||||
eta: float = 1.0,
|
||||
scale_factor: float = 0.55,
|
||||
):
|
||||
"""Adds noise from a standard normal distribution to the gradients.
|
||||
|
||||
The standard deviation (`sigma`) is controlled
|
||||
by the three hyper-parameters below.
|
||||
`sigma` goes to zero (no noise) with more iterations.
|
||||
|
||||
Args:
|
||||
model: Model.
|
||||
iteration: Number of iterations.
|
||||
duration: {100, 1000}: Number of durations to control
|
||||
the interval of the `sigma` change.
|
||||
eta: {0.01, 0.3, 1.0}: The magnitude of `sigma`.
|
||||
scale_factor: {0.55}: The scale of `sigma`.
|
||||
"""
|
||||
interval = (iteration // duration) + 1
|
||||
sigma = eta / interval**scale_factor
|
||||
for param in model.parameters():
|
||||
if param.grad is not None:
|
||||
_shape = param.grad.size()
|
||||
noise = sigma * torch.randn(_shape).to(param.device)
|
||||
param.grad += noise
|
||||
97
modules/python/vendors/FunASR/funasr/train_utils/average_nbest_models.py
vendored
Normal file
97
modules/python/vendors/FunASR/funasr/train_utils/average_nbest_models.py
vendored
Normal file
@@ -0,0 +1,97 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Sequence
|
||||
from typing import Union
|
||||
import warnings
|
||||
import os
|
||||
from io import BytesIO
|
||||
|
||||
import torch
|
||||
from typing import Collection
|
||||
import os
|
||||
import torch
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from functools import cmp_to_key
|
||||
|
||||
|
||||
def _get_checkpoint_paths(output_dir: str, last_n: int = 5, use_deepspeed=False, **kwargs):
|
||||
"""
|
||||
Get the paths of the last 'last_n' checkpoints by parsing filenames
|
||||
in the output directory.
|
||||
"""
|
||||
try:
|
||||
if not use_deepspeed:
|
||||
checkpoint = torch.load(os.path.join(output_dir, "model.pt"), map_location="cpu")
|
||||
else:
|
||||
checkpoint = torch.load(
|
||||
os.path.join(output_dir, "model.pt", "mp_rank_00_model_states.pt"),
|
||||
map_location="cpu",
|
||||
)
|
||||
avg_keep_nbest_models_type = checkpoint["avg_keep_nbest_models_type"]
|
||||
val_step_or_epoch = checkpoint[f"val_{avg_keep_nbest_models_type}_step_or_epoch"]
|
||||
sorted_items = sorted(val_step_or_epoch.items(), key=lambda x: x[1], reverse=True)
|
||||
sorted_items = (
|
||||
sorted_items[:last_n] if avg_keep_nbest_models_type == "acc" else sorted_items[-last_n:]
|
||||
)
|
||||
checkpoint_paths = []
|
||||
for key, value in sorted_items[:last_n]:
|
||||
if not use_deepspeed:
|
||||
ckpt = os.path.join(output_dir, key)
|
||||
else:
|
||||
ckpt = os.path.join(output_dir, key, "mp_rank_00_model_states.pt")
|
||||
checkpoint_paths.append(ckpt)
|
||||
|
||||
except:
|
||||
print(f"{checkpoint} does not exist, avg the lastet checkpoint.")
|
||||
# List all files in the output directory
|
||||
files = os.listdir(output_dir)
|
||||
# Filter out checkpoint files and extract epoch numbers
|
||||
checkpoint_files = [f for f in files if f.startswith("model.pt.e")]
|
||||
# Sort files by epoch number in descending order
|
||||
checkpoint_files.sort(key=lambda x: int(re.search(r"(\d+)", x).group()), reverse=True)
|
||||
# Get the last 'last_n' checkpoint paths
|
||||
checkpoint_paths = [os.path.join(output_dir, f) for f in checkpoint_files[:last_n]]
|
||||
|
||||
return checkpoint_paths
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def average_checkpoints(output_dir: str, last_n: int = 5, **kwargs):
|
||||
"""
|
||||
Average the last 'last_n' checkpoints' model state_dicts.
|
||||
If a tensor is of type torch.int, perform sum instead of average.
|
||||
"""
|
||||
checkpoint_paths = _get_checkpoint_paths(output_dir, last_n, **kwargs)
|
||||
print(f"average_checkpoints: {checkpoint_paths}")
|
||||
state_dicts = []
|
||||
|
||||
# Load state_dicts from checkpoints
|
||||
for path in checkpoint_paths:
|
||||
if os.path.isfile(path):
|
||||
state_dicts.append(torch.load(path, map_location="cpu")["state_dict"])
|
||||
else:
|
||||
print(f"Checkpoint file {path} not found.")
|
||||
|
||||
# Check if we have any state_dicts to average
|
||||
if len(state_dicts) < 1:
|
||||
print("No checkpoints found for averaging.")
|
||||
return
|
||||
|
||||
# Average or sum weights
|
||||
avg_state_dict = OrderedDict()
|
||||
for key in state_dicts[0].keys():
|
||||
tensors = [state_dict[key].cpu() for state_dict in state_dicts]
|
||||
# Check the type of the tensor
|
||||
if str(tensors[0].dtype).startswith("torch.int"):
|
||||
# Perform sum for integer tensors
|
||||
summed_tensor = sum(tensors)
|
||||
avg_state_dict[key] = summed_tensor
|
||||
else:
|
||||
# Perform average for other types of tensors
|
||||
stacked_tensors = torch.stack(tensors)
|
||||
avg_state_dict[key] = torch.mean(stacked_tensors, dim=0)
|
||||
checkpoint_outpath = os.path.join(output_dir, f"model.pt.avg{last_n}")
|
||||
torch.save({"state_dict": avg_state_dict}, checkpoint_outpath)
|
||||
return checkpoint_outpath
|
||||
64
modules/python/vendors/FunASR/funasr/train_utils/device_funcs.py
vendored
Normal file
64
modules/python/vendors/FunASR/funasr/train_utils/device_funcs.py
vendored
Normal file
@@ -0,0 +1,64 @@
|
||||
import dataclasses
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def to_device(data, device=None, dtype=None, non_blocking=False, copy=False):
|
||||
"""Change the device of object recursively"""
|
||||
if isinstance(data, dict):
|
||||
return {k: to_device(v, device, dtype, non_blocking, copy) for k, v in data.items()}
|
||||
elif dataclasses.is_dataclass(data) and not isinstance(data, type):
|
||||
return type(data)(
|
||||
*[to_device(v, device, dtype, non_blocking, copy) for v in dataclasses.astuple(data)]
|
||||
)
|
||||
# maybe namedtuple. I don't know the correct way to judge namedtuple.
|
||||
elif isinstance(data, tuple) and type(data) is not tuple:
|
||||
return type(data)(*[to_device(o, device, dtype, non_blocking, copy) for o in data])
|
||||
elif isinstance(data, (list, tuple)):
|
||||
return type(data)(to_device(v, device, dtype, non_blocking, copy) for v in data)
|
||||
elif isinstance(data, np.ndarray):
|
||||
return to_device(torch.from_numpy(data), device, dtype, non_blocking, copy)
|
||||
elif isinstance(data, torch.Tensor):
|
||||
return data.to(device, dtype, non_blocking, copy)
|
||||
else:
|
||||
return data
|
||||
|
||||
|
||||
def force_gatherable(data, device):
|
||||
"""Change object to gatherable in torch.nn.DataParallel recursively
|
||||
|
||||
The difference from to_device() is changing to torch.Tensor if float or int
|
||||
value is found.
|
||||
|
||||
The restriction to the returned value in DataParallel:
|
||||
The object must be
|
||||
- torch.cuda.Tensor
|
||||
- 1 or more dimension. 0-dimension-tensor sends warning.
|
||||
or a list, tuple, dict.
|
||||
|
||||
"""
|
||||
if isinstance(data, dict):
|
||||
return {k: force_gatherable(v, device) for k, v in data.items()}
|
||||
# DataParallel can't handle NamedTuple well
|
||||
elif isinstance(data, tuple) and type(data) is not tuple:
|
||||
return type(data)(*[force_gatherable(o, device) for o in data])
|
||||
elif isinstance(data, (list, tuple, set)):
|
||||
return type(data)(force_gatherable(v, device) for v in data)
|
||||
elif isinstance(data, np.ndarray):
|
||||
return force_gatherable(torch.from_numpy(data), device)
|
||||
elif isinstance(data, torch.Tensor):
|
||||
if data.dim() == 0:
|
||||
# To 1-dim array
|
||||
data = data[None]
|
||||
return data.to(device)
|
||||
elif isinstance(data, float):
|
||||
return torch.tensor([data], dtype=torch.float, device=device)
|
||||
elif isinstance(data, int):
|
||||
return torch.tensor([data], dtype=torch.long, device=device)
|
||||
elif data is None:
|
||||
return None
|
||||
else:
|
||||
warnings.warn(f"{type(data)} may not be gatherable by DataParallel")
|
||||
return data
|
||||
31
modules/python/vendors/FunASR/funasr/train_utils/forward_adaptor.py
vendored
Normal file
31
modules/python/vendors/FunASR/funasr/train_utils/forward_adaptor.py
vendored
Normal file
@@ -0,0 +1,31 @@
|
||||
import torch
|
||||
|
||||
|
||||
class ForwardAdaptor(torch.nn.Module):
|
||||
"""Wrapped module to parallelize specified method
|
||||
|
||||
torch.nn.DataParallel parallelizes only "forward()"
|
||||
and, maybe, the method having the other name can't be applied
|
||||
except for wrapping the module just like this class.
|
||||
|
||||
Examples:
|
||||
>>> class A(torch.nn.Module):
|
||||
... def foo(self, x):
|
||||
... ...
|
||||
>>> model = A()
|
||||
>>> model = ForwardAdaptor(model, "foo")
|
||||
>>> model = torch.nn.DataParallel(model, device_ids=[0, 1])
|
||||
>>> x = torch.randn(2, 10)
|
||||
>>> model(x)
|
||||
"""
|
||||
|
||||
def __init__(self, module: torch.nn.Module, name: str):
|
||||
super().__init__()
|
||||
self.module = module
|
||||
self.name = name
|
||||
if not hasattr(module, name):
|
||||
raise ValueError(f"{module} doesn't have {name}")
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
func = getattr(self.module, self.name)
|
||||
return func(*args, **kwargs)
|
||||
55
modules/python/vendors/FunASR/funasr/train_utils/initialize.py
vendored
Normal file
55
modules/python/vendors/FunASR/funasr/train_utils/initialize.py
vendored
Normal file
@@ -0,0 +1,55 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
"""Initialize modules for espnet2 neural networks."""
|
||||
|
||||
import math
|
||||
import torch
|
||||
|
||||
|
||||
def initialize(model: torch.nn.Module, init: str):
|
||||
"""Initialize weights of a neural network module.
|
||||
|
||||
Parameters are initialized using the given method or distribution.
|
||||
|
||||
Custom initialization routines can be implemented into submodules
|
||||
as function `espnet_initialization_fn` within the custom module.
|
||||
|
||||
Args:
|
||||
model: Target.
|
||||
init: Method of initialization.
|
||||
"""
|
||||
|
||||
# weight init
|
||||
for p in model.parameters():
|
||||
if p.dim() > 1:
|
||||
if init == "xavier_uniform":
|
||||
torch.nn.init.xavier_uniform_(p.data)
|
||||
elif init == "xavier_normal":
|
||||
torch.nn.init.xavier_normal_(p.data)
|
||||
elif init == "kaiming_uniform":
|
||||
torch.nn.init.kaiming_uniform_(p.data, nonlinearity="relu")
|
||||
elif init == "kaiming_normal":
|
||||
torch.nn.init.kaiming_normal_(p.data, nonlinearity="relu")
|
||||
else:
|
||||
raise ValueError("Unknown initialization: " + init)
|
||||
# bias init
|
||||
for p in model.parameters():
|
||||
if p.dim() == 1:
|
||||
p.data.zero_()
|
||||
|
||||
# reset some modules with default init
|
||||
for m in model.modules():
|
||||
if isinstance(m, (torch.nn.Embedding, torch.nn.LayerNorm, torch.nn.GroupNorm)):
|
||||
m.reset_parameters()
|
||||
if hasattr(m, "espnet_initialization_fn"):
|
||||
m.espnet_initialization_fn()
|
||||
|
||||
# TODO(xkc): Hacking s3prl_frontend and wav2vec2encoder initialization
|
||||
if getattr(model, "encoder", None) and getattr(
|
||||
model.encoder, "reload_pretrained_parameters", None
|
||||
):
|
||||
model.encoder.reload_pretrained_parameters()
|
||||
if getattr(model, "frontend", None) and getattr(
|
||||
model.frontend, "reload_pretrained_parameters", None
|
||||
):
|
||||
model.frontend.reload_pretrained_parameters()
|
||||
103
modules/python/vendors/FunASR/funasr/train_utils/load_pretrained_model.py
vendored
Normal file
103
modules/python/vendors/FunASR/funasr/train_utils/load_pretrained_model.py
vendored
Normal file
@@ -0,0 +1,103 @@
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import Union
|
||||
from io import BytesIO
|
||||
|
||||
import logging
|
||||
import torch
|
||||
import torch.nn
|
||||
import torch.optim
|
||||
import pdb
|
||||
import copy
|
||||
|
||||
|
||||
def load_pretrained_model(
|
||||
path: str,
|
||||
model: torch.nn.Module,
|
||||
ignore_init_mismatch: bool = True,
|
||||
map_location: str = "cpu",
|
||||
oss_bucket=None,
|
||||
scope_map=[],
|
||||
excludes=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Load a model state and set it to the model.
|
||||
|
||||
Args:
|
||||
init_param: <file_path>:<src_key>:<dst_key>:<exclude_Keys>
|
||||
|
||||
Examples:
|
||||
|
||||
"""
|
||||
|
||||
obj = model
|
||||
dst_state = obj.state_dict()
|
||||
|
||||
logging.info(f"ckpt: {path}")
|
||||
|
||||
if oss_bucket is None:
|
||||
ori_state = torch.load(path, map_location=map_location)
|
||||
else:
|
||||
buffer = BytesIO(oss_bucket.get_object(path).read())
|
||||
ori_state = torch.load(buffer, map_location=map_location)
|
||||
|
||||
src_state = copy.deepcopy(ori_state)
|
||||
src_state = src_state["state_dict"] if "state_dict" in src_state else src_state
|
||||
src_state = src_state["model_state_dict"] if "model_state_dict" in src_state else src_state
|
||||
src_state = src_state["model"] if "model" in src_state else src_state
|
||||
|
||||
if isinstance(scope_map, str):
|
||||
scope_map = scope_map.split(",")
|
||||
scope_map += ["module.", "None"]
|
||||
logging.info(f"scope_map: {scope_map}")
|
||||
|
||||
if excludes is not None:
|
||||
if isinstance(excludes, str):
|
||||
excludes = excludes.split(",")
|
||||
|
||||
logging.info(f"excludes: {excludes}")
|
||||
|
||||
for k in dst_state.keys():
|
||||
excludes_flag = False
|
||||
if excludes is not None:
|
||||
for k_ex in excludes:
|
||||
if k.startswith(k_ex):
|
||||
logging.info(f"key: {k} matching: {k_ex}, excluded")
|
||||
excludes_flag = True
|
||||
break
|
||||
if excludes_flag:
|
||||
continue
|
||||
|
||||
k_src = k
|
||||
|
||||
if scope_map is not None:
|
||||
src_prefix = ""
|
||||
dst_prefix = ""
|
||||
for i in range(0, len(scope_map), 2):
|
||||
src_prefix = scope_map[i] if scope_map[i].lower() != "none" else ""
|
||||
dst_prefix = scope_map[i + 1] if scope_map[i + 1].lower() != "none" else ""
|
||||
|
||||
if dst_prefix == "" and (src_prefix + k) in src_state.keys():
|
||||
k_src = src_prefix + k
|
||||
if not k_src.startswith("module."):
|
||||
logging.info(f"init param, map: {k} from {k_src} in ckpt")
|
||||
elif (
|
||||
k.startswith(dst_prefix)
|
||||
and k.replace(dst_prefix, src_prefix, 1) in src_state.keys()
|
||||
):
|
||||
k_src = k.replace(dst_prefix, src_prefix, 1)
|
||||
if not k_src.startswith("module."):
|
||||
logging.info(f"init param, map: {k} from {k_src} in ckpt")
|
||||
|
||||
if k_src in src_state.keys():
|
||||
if ignore_init_mismatch and dst_state[k].shape != src_state[k_src].shape:
|
||||
logging.info(
|
||||
f"ignore_init_mismatch:{ignore_init_mismatch}, dst: {k, dst_state[k].shape}, src: {k_src, src_state[k_src].shape}"
|
||||
)
|
||||
else:
|
||||
dst_state[k] = src_state[k_src]
|
||||
else:
|
||||
print(f"Warning, miss key in ckpt: {k}, {path}")
|
||||
|
||||
flag = obj.load_state_dict(dst_state, strict=True)
|
||||
logging.info(f"Loading ckpt: {path}, status: {flag}")
|
||||
72
modules/python/vendors/FunASR/funasr/train_utils/model_summary.py
vendored
Normal file
72
modules/python/vendors/FunASR/funasr/train_utils/model_summary.py
vendored
Normal file
@@ -0,0 +1,72 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def get_human_readable_count(number: int) -> str:
|
||||
"""Return human_readable_count
|
||||
|
||||
Originated from:
|
||||
https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/core/memory.py
|
||||
|
||||
Abbreviates an integer number with K, M, B, T for thousands, millions,
|
||||
billions and trillions, respectively.
|
||||
Examples:
|
||||
>>> get_human_readable_count(123)
|
||||
'123 '
|
||||
>>> get_human_readable_count(1234) # (one thousand)
|
||||
'1 K'
|
||||
>>> get_human_readable_count(2e6) # (two million)
|
||||
'2 M'
|
||||
>>> get_human_readable_count(3e9) # (three billion)
|
||||
'3 B'
|
||||
>>> get_human_readable_count(4e12) # (four trillion)
|
||||
'4 T'
|
||||
>>> get_human_readable_count(5e15) # (more than trillion)
|
||||
'5,000 T'
|
||||
Args:
|
||||
number: a positive integer number
|
||||
Return:
|
||||
A string formatted according to the pattern described above.
|
||||
"""
|
||||
assert number >= 0
|
||||
labels = [" ", "K", "M", "B", "T"]
|
||||
num_digits = int(np.floor(np.log10(number)) + 1 if number > 0 else 1)
|
||||
num_groups = int(np.ceil(num_digits / 3))
|
||||
num_groups = min(num_groups, len(labels)) # don't abbreviate beyond trillions
|
||||
shift = -3 * (num_groups - 1)
|
||||
number = number * (10**shift)
|
||||
index = num_groups - 1
|
||||
return f"{number:.2f} {labels[index]}"
|
||||
|
||||
|
||||
def to_bytes(dtype) -> int:
|
||||
# torch.float16 -> 16
|
||||
return int(str(dtype)[-2:]) // 8
|
||||
|
||||
|
||||
def model_summary(model: torch.nn.Module) -> str:
|
||||
message = "Model structure:\n"
|
||||
message += str(model)
|
||||
|
||||
tot_params, num_params = 0, 0
|
||||
for name, param in model.named_parameters():
|
||||
print(
|
||||
"name: {}, dtype: {}, device: {}, trainable: {}, shape: {}, numel: {}".format(
|
||||
name, param.dtype, param.device, param.requires_grad, param.shape, param.numel()
|
||||
)
|
||||
)
|
||||
tot_params += param.numel()
|
||||
if param.requires_grad:
|
||||
num_params += param.numel()
|
||||
|
||||
percent_trainable = "{:.1f}".format(num_params * 100.0 / tot_params)
|
||||
tot_params = get_human_readable_count(tot_params)
|
||||
num_params = get_human_readable_count(num_params)
|
||||
message += "\n\nModel summary:\n"
|
||||
message += f" Class Name: {model.__class__.__name__}\n"
|
||||
message += f" Total Number of model parameters: {tot_params}\n"
|
||||
message += f" Number of trainable parameters: {num_params} ({percent_trainable}%)\n"
|
||||
|
||||
dtype = next(iter(model.parameters())).dtype
|
||||
message += f" Type: {dtype}"
|
||||
return message
|
||||
48
modules/python/vendors/FunASR/funasr/train_utils/recursive_op.py
vendored
Normal file
48
modules/python/vendors/FunASR/funasr/train_utils/recursive_op.py
vendored
Normal file
@@ -0,0 +1,48 @@
|
||||
"""Torch utility module."""
|
||||
|
||||
import torch
|
||||
|
||||
if torch.distributed.is_available():
|
||||
from torch.distributed import ReduceOp
|
||||
|
||||
|
||||
def recursive_sum(obj, weight: torch.Tensor, distributed: bool = False):
|
||||
assert weight.dim() == 1, weight.size()
|
||||
if isinstance(obj, (tuple, list)):
|
||||
return type(obj)(recursive_sum(v, weight, distributed) for v in obj)
|
||||
elif isinstance(obj, dict):
|
||||
return {k: recursive_sum(v, weight, distributed) for k, v in obj.items()}
|
||||
elif isinstance(obj, torch.Tensor):
|
||||
assert obj.size() == weight.size(), (obj.size(), weight.size())
|
||||
obj = (obj * weight.type(obj.dtype)).sum()
|
||||
if distributed:
|
||||
torch.distributed.all_reduce(obj, op=ReduceOp.SUM)
|
||||
return obj
|
||||
elif obj is None:
|
||||
return None
|
||||
else:
|
||||
raise ValueError(type(obj))
|
||||
|
||||
|
||||
def recursive_divide(a, b: torch.Tensor):
|
||||
if isinstance(a, (tuple, list)):
|
||||
return type(a)(recursive_divide(v, b) for v in a)
|
||||
elif isinstance(a, dict):
|
||||
return {k: recursive_divide(v, b) for k, v in a.items()}
|
||||
elif isinstance(a, torch.Tensor):
|
||||
assert a.size() == b.size(), (a.size(), b.size())
|
||||
return a / b.type(a.dtype)
|
||||
elif a is None:
|
||||
return None
|
||||
else:
|
||||
raise ValueError(type(a))
|
||||
|
||||
|
||||
def recursive_average(obj, weight: torch.Tensor, distributed: bool = False):
|
||||
obj = recursive_sum(obj, weight, distributed)
|
||||
weight = weight.sum()
|
||||
if distributed:
|
||||
torch.distributed.all_reduce(weight, op=ReduceOp.SUM)
|
||||
# Normalize weight to be sum-to-1
|
||||
obj = recursive_divide(obj, weight)
|
||||
return obj, weight
|
||||
10
modules/python/vendors/FunASR/funasr/train_utils/set_all_random_seed.py
vendored
Normal file
10
modules/python/vendors/FunASR/funasr/train_utils/set_all_random_seed.py
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def set_all_random_seed(seed: int):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.random.manual_seed(seed)
|
||||
732
modules/python/vendors/FunASR/funasr/train_utils/trainer.py
vendored
Normal file
732
modules/python/vendors/FunASR/funasr/train_utils/trainer.py
vendored
Normal file
@@ -0,0 +1,732 @@
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
import torch
|
||||
import logging
|
||||
from tqdm import tqdm
|
||||
from datetime import datetime
|
||||
import torch.distributed as dist
|
||||
from torch.cuda.amp import autocast, GradScaler
|
||||
from contextlib import nullcontext, contextmanager
|
||||
from pathlib import Path
|
||||
|
||||
from funasr.train_utils.device_funcs import to_device
|
||||
from funasr.train_utils.recursive_op import recursive_average
|
||||
from funasr.train_utils.average_nbest_models import average_checkpoints
|
||||
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
|
||||
|
||||
try:
|
||||
import wandb
|
||||
except:
|
||||
wandb = None
|
||||
|
||||
|
||||
@contextmanager
|
||||
def maybe_autocast(enabled):
|
||||
if enabled:
|
||||
with autocast():
|
||||
yield
|
||||
else:
|
||||
yield
|
||||
|
||||
|
||||
class Trainer:
|
||||
"""
|
||||
A simple trainer class for training a PyTorch model, saving checkpoints at the end of each epoch,
|
||||
and optionally resuming from a saved checkpoint.
|
||||
|
||||
Attributes:
|
||||
max_epoch (int): Maximum number of epochs for training.
|
||||
model (torch.nn.Module): The model to be trained.
|
||||
optim (torch.optim.Optimizer): The optimizer to use for training.
|
||||
scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
|
||||
dataloader_train (torch.utils.data.DataLoader): DataLoader for the training dataset.
|
||||
dataloader_val (torch.utils.data.DataLoader): DataLoader for the validation dataset.
|
||||
output_dir (str): Directory where model checkpoints will be saved.
|
||||
resume (str, optional): Path to a checkpoint to resume training from.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
local_rank,
|
||||
use_ddp: bool = False,
|
||||
use_fsdp: bool = False,
|
||||
use_fp16: bool = False,
|
||||
output_dir: str = "./",
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initializes the Trainer class with the model, optimizer, scheduler, dataloaders, and other settings.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): The model to be trained.
|
||||
optim (torch.optim.Optimizer): The optimizer to use for training.
|
||||
scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
|
||||
dataloader_train (torch.utils.data.DataLoader): The DataLoader for the training dataset.
|
||||
dataloader_val (torch.utils.data.DataLoader): The DataLoader for the validation dataset.
|
||||
**kwargs: Additional keyword arguments:
|
||||
max_epoch (int): The maximum number of epochs for training.
|
||||
output_dir (str): The directory where model checkpoints will be saved. Default is './'.
|
||||
resume (str, optional): The file path to a checkpoint to resume training from.
|
||||
"""
|
||||
|
||||
self.output_dir = output_dir
|
||||
if not os.path.exists(self.output_dir):
|
||||
os.makedirs(self.output_dir, exist_ok=True)
|
||||
self.resume = kwargs.get("resume", True)
|
||||
self.start_epoch = 0
|
||||
self.max_epoch = kwargs.get("max_epoch", 100)
|
||||
self.local_rank = local_rank
|
||||
self.use_ddp = use_ddp
|
||||
self.use_fsdp = use_fsdp
|
||||
self.device = kwargs.get("device", "cuda")
|
||||
# self.kwargs = kwargs
|
||||
self.log_interval = kwargs.get("log_interval", 50)
|
||||
self.batch_total = 0
|
||||
self.use_fp16 = use_fp16
|
||||
self.save_checkpoint_interval = kwargs.get("save_checkpoint_interval", 5000)
|
||||
self.validate_interval = kwargs.get("validate_interval", -1)
|
||||
if self.validate_interval < 0:
|
||||
self.validate_interval = self.save_checkpoint_interval
|
||||
assert (
|
||||
self.save_checkpoint_interval == self.validate_interval
|
||||
), f"save_checkpoint_interval must equal to validate_interval"
|
||||
self.keep_nbest_models = kwargs.get("keep_nbest_models", 500)
|
||||
self.avg_keep_nbest_models_type = kwargs.get("avg_keep_nbest_models_type", "acc")
|
||||
self.avg_nbest_model = kwargs.get("avg_nbest_model", 10)
|
||||
self.accum_grad = kwargs.get("accum_grad", 1)
|
||||
self.grad_clip = kwargs.get("grad_clip", 10.0)
|
||||
self.grad_clip_type = kwargs.get("grad_clip_type", 2.0)
|
||||
|
||||
try:
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
except:
|
||||
rank = 0
|
||||
world_size = 1
|
||||
logging.warning("distributed is not initialized, only single shard")
|
||||
self.rank = rank
|
||||
self.world_size = world_size
|
||||
self.train_acc_avg = 0.0
|
||||
self.train_loss_avg = 0.0
|
||||
self.val_acc_avg = 0.0
|
||||
self.val_loss_avg = 0.0
|
||||
self.best_acc_idx = 0
|
||||
self.saved_ckpts = {}
|
||||
self.step_or_epoch = -1
|
||||
self.best_step_or_epoch = ""
|
||||
self.val_acc_step_or_epoch = {}
|
||||
self.val_loss_step_or_epoch = {}
|
||||
|
||||
self.reset_gpu_cache = kwargs.get("reset_gpu_cache", False)
|
||||
self.start_data_split_i = 0
|
||||
self.start_step = 0
|
||||
self.step_in_epoch = 0
|
||||
self.use_wandb = kwargs.get("use_wandb", False)
|
||||
if self.use_wandb:
|
||||
wandb.login(key=kwargs.get("wandb_token"))
|
||||
wandb.init(
|
||||
config=kwargs,
|
||||
project=kwargs.get("wandb_project", "my_project"),
|
||||
entity=kwargs.get("wandb_team", "my_team"),
|
||||
name=kwargs.get("wandb_exp_name", "my_exp"),
|
||||
dir=output_dir,
|
||||
job_type="training",
|
||||
reinit=True,
|
||||
)
|
||||
|
||||
def save_checkpoint(
|
||||
self,
|
||||
epoch,
|
||||
step=None,
|
||||
model=None,
|
||||
optim=None,
|
||||
scheduler=None,
|
||||
scaler=None,
|
||||
step_in_epoch=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Saves a checkpoint containing the model's state, the optimizer's state,
|
||||
and the scheduler's state at the end of the given epoch. This method is
|
||||
intended to be called at the end of each epoch to save the training progress.
|
||||
|
||||
Args:
|
||||
epoch (int): The epoch number at which the checkpoint is being saved.
|
||||
"""
|
||||
|
||||
step_in_epoch = None if step is None else step_in_epoch
|
||||
if self.rank == 0:
|
||||
logging.info(f"Save checkpoint: {epoch}, rank: {self.local_rank}\n")
|
||||
# self.step_or_epoch += 1
|
||||
state = {
|
||||
"epoch": epoch,
|
||||
'step': step,
|
||||
'total_step': self.batch_total,
|
||||
"state_dict": model.state_dict(),
|
||||
"optimizer": optim.state_dict(),
|
||||
"scheduler": scheduler.state_dict(),
|
||||
"saved_ckpts": self.saved_ckpts,
|
||||
"val_acc_step_or_epoch": self.val_acc_step_or_epoch,
|
||||
"val_loss_step_or_epoch": self.val_loss_step_or_epoch,
|
||||
"best_step_or_epoch": self.best_step_or_epoch,
|
||||
"avg_keep_nbest_models_type": self.avg_keep_nbest_models_type,
|
||||
"step": step,
|
||||
"step_in_epoch": step_in_epoch,
|
||||
"data_split_i": kwargs.get("data_split_i", 0),
|
||||
"data_split_num": kwargs.get("data_split_num", 1),
|
||||
"batch_total": self.batch_total,
|
||||
"train_loss_avg": kwargs.get("train_loss_avg", 0),
|
||||
"train_acc_avg": kwargs.get("train_acc_avg", 0),
|
||||
}
|
||||
step = step_in_epoch
|
||||
if hasattr(model, "module"):
|
||||
state["state_dict"] = model.module.state_dict()
|
||||
|
||||
if scaler:
|
||||
state["scaler_state"] = scaler.state_dict()
|
||||
|
||||
# Create output directory if it does not exist
|
||||
os.makedirs(self.output_dir, exist_ok=True)
|
||||
if step is None:
|
||||
ckpt_name = f"model.pt.ep{epoch}"
|
||||
else:
|
||||
ckpt_name = f"model.pt.ep{epoch}.{step}"
|
||||
filename = os.path.join(self.output_dir, ckpt_name)
|
||||
torch.save(state, filename)
|
||||
logging.info(f'Checkpoint saved to {filename}')
|
||||
|
||||
latest = Path(os.path.join(self.output_dir, f'model.pt'))
|
||||
torch.save(state, latest)
|
||||
|
||||
if self.best_step_or_epoch == "":
|
||||
self.best_step_or_epoch = ckpt_name
|
||||
|
||||
if self.avg_keep_nbest_models_type == "acc":
|
||||
if (
|
||||
self.val_acc_step_or_epoch[ckpt_name]
|
||||
>= self.val_acc_step_or_epoch[self.best_step_or_epoch]
|
||||
):
|
||||
self.best_step_or_epoch = ckpt_name
|
||||
best_ckpt = Path(os.path.join(self.output_dir, f"model.pt.best"))
|
||||
torch.save(state, best_ckpt)
|
||||
logging.info(
|
||||
f"Update best acc: {self.val_acc_step_or_epoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
|
||||
)
|
||||
else:
|
||||
logging.info(
|
||||
f"No improvement in acc: {self.val_acc_step_or_epoch[ckpt_name]:.4f} < {self.val_acc_step_or_epoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
|
||||
)
|
||||
elif self.avg_keep_nbest_models_type == "loss":
|
||||
if (
|
||||
self.val_loss_step_or_epoch[ckpt_name]
|
||||
<= self.val_loss_step_or_epoch[self.best_step_or_epoch]
|
||||
):
|
||||
self.best_step_or_epoch = ckpt_name
|
||||
best_ckpt = Path(os.path.join(self.output_dir, f"model.pt.best"))
|
||||
torch.save(state, best_ckpt)
|
||||
logging.info(
|
||||
f"Update best loss: {self.val_loss_step_or_epoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
|
||||
)
|
||||
else:
|
||||
logging.info(
|
||||
f"No improvement in loss: {self.val_loss_step_or_epoch[ckpt_name]:.4f} > {self.val_loss_step_or_epoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
|
||||
)
|
||||
else:
|
||||
print("Undo")
|
||||
self.saved_ckpts[ckpt_name] = getattr(
|
||||
self, f"val_{self.avg_keep_nbest_models_type}_step_or_epoch"
|
||||
)[ckpt_name]
|
||||
if self.keep_nbest_models > 0:
|
||||
if len(self.saved_ckpts) > self.keep_nbest_models:
|
||||
if self.avg_keep_nbest_models_type == "acc":
|
||||
key = min(self.saved_ckpts, key=self.saved_ckpts.get)
|
||||
else:
|
||||
key = max(self.saved_ckpts, key=self.saved_ckpts.get)
|
||||
if key in self.saved_ckpts:
|
||||
del self.saved_ckpts[key]
|
||||
filename = os.path.join(self.output_dir, key)
|
||||
logging.info(f"Delete: {filename}")
|
||||
if os.path.exists(filename):
|
||||
os.remove(filename)
|
||||
|
||||
if self.use_ddp or self.use_fsdp:
|
||||
dist.barrier()
|
||||
|
||||
def resume_checkpoint(
|
||||
self,
|
||||
model=None,
|
||||
optim=None,
|
||||
scheduler=None,
|
||||
scaler=None,
|
||||
):
|
||||
"""
|
||||
Resumes training from a checkpoint at the given file path.
|
||||
Loads the model's state, the optimizer's state, and the scheduler's state.
|
||||
|
||||
Args:
|
||||
resume_path (str): The file path to the checkpoint to resume from.
|
||||
"""
|
||||
if self.resume:
|
||||
ckpt = os.path.join(self.output_dir, "model.pt")
|
||||
if os.path.isfile(ckpt):
|
||||
checkpoint = torch.load(ckpt, map_location="cpu")
|
||||
self.start_epoch = checkpoint["epoch"]
|
||||
# self.model.load_state_dict(checkpoint['state_dict'])
|
||||
src_state = checkpoint["state_dict"]
|
||||
dst_state = model.state_dict()
|
||||
for k in dst_state.keys():
|
||||
if not k.startswith("module.") and "module." + k in src_state.keys():
|
||||
k_ddp = "module." + k
|
||||
elif k.startswith("module.") and "module." + k not in src_state.keys():
|
||||
k_ddp = k.replace("module.", "", 1)
|
||||
else:
|
||||
k_ddp = k
|
||||
|
||||
if k_ddp in src_state.keys():
|
||||
dst_state[k] = src_state[k_ddp]
|
||||
else:
|
||||
print(f"Miss key in ckpt: model: {k}, ckpt: {k_ddp}")
|
||||
|
||||
model.load_state_dict(dst_state)
|
||||
optim.load_state_dict(checkpoint["optimizer"])
|
||||
scheduler.load_state_dict(checkpoint["scheduler"])
|
||||
if scaler is not None and "scaler_state" in checkpoint:
|
||||
scaler.load_state_dict(checkpoint["scaler_state"])
|
||||
|
||||
self.saved_ckpts = checkpoint["saved_ckpts"]
|
||||
self.val_acc_step_or_epoch = (
|
||||
checkpoint["val_acc_step_or_epoch"]
|
||||
if "val_acc_step_or_epoch" in checkpoint
|
||||
else {}
|
||||
)
|
||||
self.val_loss_step_or_epoch = (
|
||||
checkpoint["val_loss_step_or_epoch"]
|
||||
if "val_loss_step_or_epoch" in checkpoint
|
||||
else {}
|
||||
)
|
||||
self.best_step_or_epoch = (
|
||||
checkpoint["best_step_or_epoch"] if "best_step_or_epoch" in checkpoint else ""
|
||||
)
|
||||
self.start_data_split_i = (
|
||||
checkpoint["data_split_i"] if "data_split_i" in checkpoint else 0
|
||||
)
|
||||
self.batch_total = checkpoint["batch_total"] if "batch_total" in checkpoint else 0
|
||||
self.start_step = checkpoint["step"] if "step" in checkpoint else 0
|
||||
self.start_step = 0 if self.start_step is None else self.start_step
|
||||
self.step_in_epoch = (
|
||||
checkpoint["step_in_epoch"] if "step_in_epoch" in checkpoint else 0
|
||||
)
|
||||
self.step_in_epoch = 0 if self.step_in_epoch is None else self.step_in_epoch
|
||||
print(checkpoint["train_acc_avg"])
|
||||
self.train_acc_avg = (
|
||||
checkpoint["train_acc_avg"] if "train_acc_avg" in checkpoint else 0
|
||||
)
|
||||
self.train_loss_avg = (
|
||||
checkpoint["train_loss_avg"] if "train_loss_avg" in checkpoint else 0
|
||||
)
|
||||
model.to(self.device)
|
||||
print(f"Checkpoint loaded successfully from '{ckpt}'")
|
||||
else:
|
||||
print(f"No checkpoint found at '{ckpt}', does not resume status!")
|
||||
|
||||
if self.use_ddp or self.use_fsdp:
|
||||
dist.barrier()
|
||||
|
||||
|
||||
def train_epoch(
|
||||
self,
|
||||
model=None,
|
||||
optim=None,
|
||||
scheduler=None,
|
||||
scaler=None,
|
||||
dataloader_train=None,
|
||||
dataloader_val=None,
|
||||
epoch=None,
|
||||
writer=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Defines the training process for a single epoch with gradient accumulation.
|
||||
Args:
|
||||
epoch (int): The current epoch number.
|
||||
"""
|
||||
if self.use_ddp or self.use_fsdp:
|
||||
dist.barrier()
|
||||
logging.info(f"Train epoch: {epoch}, rank: {self.rank}\n")
|
||||
model.train()
|
||||
|
||||
# Set the number of steps for gradient accumulation
|
||||
accum_grad = self.accum_grad
|
||||
# Initialize the gradient accumulation
|
||||
optim.zero_grad()
|
||||
speed_stats = {}
|
||||
|
||||
iterator_stop = torch.tensor(0).to(self.device)
|
||||
|
||||
dataloader_train.batch_sampler.set_epoch(epoch)
|
||||
time_beg = time.perf_counter()
|
||||
time5 = time_beg
|
||||
for batch_idx, batch in enumerate(dataloader_train):
|
||||
# if self.use_ddp or self.use_fsdp:
|
||||
# dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
|
||||
# if iterator_stop > 0:
|
||||
# break
|
||||
self.batch_total += 1
|
||||
self.step_in_epoch += 1
|
||||
time1 = time.perf_counter()
|
||||
speed_stats["data_load"] = f"{time1-time_beg:0.3f}"
|
||||
|
||||
batch = to_device(batch, self.device)
|
||||
|
||||
my_context = nullcontext
|
||||
if self.use_ddp or self.use_fsdp:
|
||||
my_context = model.no_sync if batch_idx % accum_grad != 0 else my_context
|
||||
with my_context():
|
||||
time2 = time.perf_counter()
|
||||
with maybe_autocast(self.use_fp16):
|
||||
retval = model(**batch)
|
||||
|
||||
# if (
|
||||
# self.reset_gpu_cache
|
||||
# and (torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024) > 70
|
||||
# ):
|
||||
# torch.cuda.empty_cache()
|
||||
|
||||
loss, stats, weight = retval
|
||||
stats = {k: v for k, v in stats.items() if v is not None}
|
||||
if self.use_ddp or self.use_fsdp:
|
||||
# Apply weighted averaging for loss and stats
|
||||
loss = (loss * weight.type(loss.dtype)).sum()
|
||||
# if distributed, this method can also apply all_reduce()
|
||||
# stats, weight = recursive_average(stats, weight, distributed=True)
|
||||
if self.use_ddp or self.use_fsdp:
|
||||
dist.all_reduce(weight, op=dist.ReduceOp.SUM)
|
||||
# Now weight is summation over all workers
|
||||
loss /= weight.sum() # shape:[1] -> shape:[]
|
||||
# Multiply world_size because DistributedDataParallel
|
||||
# automatically normalizes the gradient by world_size.
|
||||
loss *= self.world_size
|
||||
# loss *= self.world_size
|
||||
# Scale the loss since we're not updating for every mini-batch
|
||||
loss = loss / accum_grad
|
||||
|
||||
time3 = time.perf_counter()
|
||||
speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
|
||||
if self.use_fp16:
|
||||
scaler.scale(loss).backward()
|
||||
else:
|
||||
loss.backward()
|
||||
time4 = time.perf_counter()
|
||||
speed_stats["backward_and_AllReaduce_time"] = f"{time4 - time3:0.3f}"
|
||||
|
||||
self.train_loss_avg = (
|
||||
self.train_loss_avg * (batch_idx + kwargs.get("start_step", 0))
|
||||
+ loss.detach().cpu().item()
|
||||
) / (batch_idx + kwargs.get("start_step", 0) + 1)
|
||||
if "acc" in stats:
|
||||
self.train_acc_avg = (
|
||||
self.train_acc_avg * (batch_idx + kwargs.get("start_step", 0))
|
||||
+ stats["acc"].detach().cpu().item()
|
||||
) / (batch_idx + kwargs.get("start_step", 0) + 1)
|
||||
|
||||
# Perform an optimizer step only after accumulating enough gradients
|
||||
if (batch_idx + 1) % accum_grad == 0:
|
||||
# Perform gradient clipping if it is set
|
||||
if self.grad_clip > 0:
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
model.parameters(),
|
||||
max_norm=self.grad_clip,
|
||||
norm_type=self.grad_clip_type,
|
||||
)
|
||||
if not torch.isfinite(grad_norm):
|
||||
logging.warning(
|
||||
f"The grad norm is {grad_norm}. Skipping updating the model."
|
||||
)
|
||||
optim.zero_grad() # Reset gradients
|
||||
continue
|
||||
|
||||
# Execute an optimization step (update model parameters)
|
||||
if self.use_ddp or self.use_fsdp:
|
||||
dist.barrier()
|
||||
if self.use_fp16:
|
||||
scaler.step(optim)
|
||||
scaler.update()
|
||||
else:
|
||||
optim.step()
|
||||
scheduler.step()
|
||||
# Clear gradients for the next accumulation stage
|
||||
optim.zero_grad(set_to_none=True)
|
||||
|
||||
if self.use_ddp or self.use_fsdp:
|
||||
train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(
|
||||
self.device
|
||||
)
|
||||
train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(
|
||||
self.device
|
||||
)
|
||||
dist.all_reduce(train_loss_avg, op=dist.ReduceOp.SUM)
|
||||
dist.all_reduce(train_acc_avg, op=dist.ReduceOp.SUM)
|
||||
self.train_loss_avg = train_loss_avg.detach().cpu().item() / self.world_size
|
||||
self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
|
||||
|
||||
total_time = f"{(time.perf_counter() - time5)/accum_grad:0.3f}"
|
||||
time5 = time.perf_counter()
|
||||
|
||||
speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
|
||||
|
||||
speed_stats["total_time"] = total_time
|
||||
lr = scheduler.get_last_lr()[0]
|
||||
batch_num_epoch = 1
|
||||
if hasattr(dataloader_train, "__len__"):
|
||||
batch_num_epoch = len(dataloader_train)
|
||||
self.log(
|
||||
epoch,
|
||||
batch_idx,
|
||||
log_step=batch_idx + kwargs.get("start_step", 0),
|
||||
step_in_epoch=self.step_in_epoch,
|
||||
batch_num_epoch=batch_num_epoch,
|
||||
lr=lr,
|
||||
loss=accum_grad * loss.detach().cpu().item(),
|
||||
speed_stats=speed_stats,
|
||||
stats=stats,
|
||||
writer=writer,
|
||||
tag="train",
|
||||
data_split_i=kwargs.get("data_split_i", 0),
|
||||
data_split_num=kwargs.get("data_split_num", 1),
|
||||
)
|
||||
|
||||
if self.step_in_epoch % self.validate_interval == 0:
|
||||
self.validate_epoch(
|
||||
model=model,
|
||||
dataloader_val=dataloader_val,
|
||||
epoch=epoch,
|
||||
writer=writer,
|
||||
step=batch_idx + 1,
|
||||
step_in_epoch=self.step_in_epoch,
|
||||
)
|
||||
|
||||
if self.step_in_epoch % self.save_checkpoint_interval == 0:
|
||||
self.save_checkpoint(
|
||||
epoch,
|
||||
model=model,
|
||||
optim=optim,
|
||||
scheduler=scheduler,
|
||||
scaler=scaler,
|
||||
step=batch_idx + 1,
|
||||
step_in_epoch=self.step_in_epoch,
|
||||
data_split_i=kwargs.get("data_split_i", 0),
|
||||
data_split_num=kwargs.get("data_split_num", 1),
|
||||
train_loss_avg=self.train_loss_avg,
|
||||
train_acc_avg=self.train_acc_avg,
|
||||
)
|
||||
|
||||
time_beg = time.perf_counter()
|
||||
# else:
|
||||
# if self.use_ddp or self.use_fsdp:
|
||||
# iterator_stop.fill_(1)
|
||||
# dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
|
||||
|
||||
if self.use_ddp or self.use_fsdp:
|
||||
dist.barrier()
|
||||
# iterator_stop = torch.tensor(0).to(self.device)
|
||||
|
||||
def validate_epoch(
|
||||
self,
|
||||
model=None,
|
||||
dataloader_val=None,
|
||||
epoch=None,
|
||||
writer=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Defines the validation process for a single epoch.
|
||||
Should be implemented with the actual model validation steps.
|
||||
|
||||
Args:
|
||||
epoch (int): The current epoch number.
|
||||
"""
|
||||
if self.use_ddp or self.use_fsdp:
|
||||
dist.barrier()
|
||||
logging.info(f"Validate epoch: {epoch}, rank: {self.rank}\n")
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
|
||||
speed_stats = {}
|
||||
time5 = time.perf_counter()
|
||||
iterator_stop = torch.tensor(0).to(self.device)
|
||||
dataloader_val.batch_sampler.set_epoch(epoch)
|
||||
for batch_idx, batch in enumerate(dataloader_val):
|
||||
if self.use_ddp or self.use_fsdp:
|
||||
dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
|
||||
if iterator_stop > 0:
|
||||
break
|
||||
time1 = time.perf_counter()
|
||||
speed_stats["data_load"] = f"{time1 - time5:0.3f}"
|
||||
batch = to_device(batch, self.device)
|
||||
|
||||
time2 = time.perf_counter()
|
||||
retval = model(**batch)
|
||||
time3 = time.perf_counter()
|
||||
speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
|
||||
loss, stats, weight = retval
|
||||
stats = {k: v for k, v in stats.items() if v is not None}
|
||||
|
||||
if self.use_ddp or self.use_fsdp:
|
||||
# Apply weighted averaging for loss and stats
|
||||
loss = (loss * weight.type(loss.dtype)).sum()
|
||||
# if distributed, this method can also apply all_reduce()
|
||||
# stats, weight = recursive_average(stats, weight, distributed=True)
|
||||
if self.use_ddp or self.use_fsdp:
|
||||
dist.all_reduce(weight, op=dist.ReduceOp.SUM)
|
||||
# Now weight is summation over all workers
|
||||
loss /= weight.sum() # shape:[1] -> shape:[]
|
||||
# Multiply world_size because DistributedDataParallel
|
||||
# automatically normalizes the gradient by world_size.
|
||||
loss *= self.world_size
|
||||
|
||||
# Scale the loss since we're not updating for every mini-batch
|
||||
loss = loss
|
||||
time4 = time.perf_counter()
|
||||
|
||||
if torch.isfinite(loss):
|
||||
self.val_loss_avg = (self.val_loss_avg * batch_idx + loss.detach().cpu().item()) / (
|
||||
batch_idx + 1
|
||||
)
|
||||
|
||||
if "acc" in stats:
|
||||
self.val_acc_avg = (
|
||||
self.val_acc_avg * batch_idx + stats["acc"].detach().cpu().item()
|
||||
) / (batch_idx + 1)
|
||||
|
||||
if self.use_ddp or self.use_fsdp:
|
||||
val_loss_avg = torch.tensor(self.val_loss_avg, dtype=torch.float32).to(
|
||||
self.device
|
||||
)
|
||||
val_acc_avg = torch.tensor(self.val_acc_avg, dtype=torch.float32).to(
|
||||
self.device
|
||||
)
|
||||
dist.all_reduce(val_loss_avg, op=dist.ReduceOp.SUM)
|
||||
dist.all_reduce(val_acc_avg, op=dist.ReduceOp.SUM)
|
||||
self.val_loss_avg = val_loss_avg.detach().cpu().item() / self.world_size
|
||||
self.val_acc_avg = val_acc_avg.detach().cpu().item() / self.world_size
|
||||
|
||||
time5 = time.perf_counter()
|
||||
batch_num_epoch = 1
|
||||
if hasattr(dataloader_val, "__len__"):
|
||||
batch_num_epoch = len(dataloader_val)
|
||||
self.log(
|
||||
epoch,
|
||||
batch_idx,
|
||||
batch_num_epoch=batch_num_epoch,
|
||||
lr=0.0,
|
||||
loss=loss.detach().cpu().item(),
|
||||
speed_stats=speed_stats,
|
||||
stats=stats,
|
||||
writer=writer,
|
||||
tag="val",
|
||||
)
|
||||
|
||||
else:
|
||||
if self.use_ddp or self.use_fsdp:
|
||||
iterator_stop.fill_(1)
|
||||
dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
|
||||
|
||||
if kwargs.get("step_in_epoch", None) is None:
|
||||
ckpt_name = f"model.pt.ep{epoch}"
|
||||
else:
|
||||
ckpt_name = f'model.pt.ep{epoch}.{kwargs.get("step_in_epoch")}'
|
||||
self.val_acc_step_or_epoch[ckpt_name] = self.val_acc_avg
|
||||
self.val_loss_step_or_epoch[ckpt_name] = self.val_loss_avg
|
||||
model.train()
|
||||
|
||||
if self.use_ddp or self.use_fsdp:
|
||||
dist.barrier()
|
||||
iterator_stop = torch.tensor(0).to(self.device)
|
||||
|
||||
def log(
|
||||
self,
|
||||
epoch=0,
|
||||
batch_idx=0,
|
||||
step_in_epoch=0,
|
||||
batch_num_epoch=-1,
|
||||
lr=0.0,
|
||||
loss=0.0,
|
||||
speed_stats=None,
|
||||
stats=None,
|
||||
writer=None,
|
||||
tag="train",
|
||||
data_split_i=0,
|
||||
data_split_num=1,
|
||||
log_step=None,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
if (batch_idx + 1) % self.log_interval == 0:
|
||||
batch_idx = log_step if log_step is not None else batch_idx
|
||||
gpu_info = (
|
||||
"GPU, memory: usage: {:.3f} GB, "
|
||||
"peak: {:.3f} GB, "
|
||||
"cache: {:.3f} GB, "
|
||||
"cache_peak: {:.3f} GB".format(
|
||||
torch.cuda.memory_allocated() / 1024 / 1024 / 1024,
|
||||
torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024,
|
||||
torch.cuda.memory_reserved() / 1024 / 1024 / 1024,
|
||||
torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024,
|
||||
)
|
||||
)
|
||||
|
||||
loss_avg_epoch = getattr(self, f"{tag}_loss_avg")
|
||||
acc_avg_epoch = getattr(self, f"{tag}_acc_avg")
|
||||
description = (
|
||||
f"{tag}, "
|
||||
f"rank: {self.rank}, "
|
||||
f"epoch: {epoch}/{self.max_epoch}, "
|
||||
f"data_slice: {data_split_i}/{data_split_num}, "
|
||||
f"step_in_slice: {batch_idx + 1}/{batch_num_epoch}, step_in_epoch: {step_in_epoch}, total step: {self.batch_total}, "
|
||||
f"(loss_avg_rank: {loss:.3f}), "
|
||||
f"(loss_avg_slice: {loss_avg_epoch:.3f}), "
|
||||
f"(ppl_avg_slice: {math.exp(loss_avg_epoch):.3e}), "
|
||||
f"(acc_avg_slice: {acc_avg_epoch:.3f}), "
|
||||
f"(lr: {lr:.3e}), "
|
||||
f"{[(k, round(v.detach().cpu().item(), 3)) for k, v in stats.items()]}, "
|
||||
f"{speed_stats}, "
|
||||
f"{gpu_info}"
|
||||
)
|
||||
logging.info(description)
|
||||
|
||||
description_dict = {
|
||||
f"rank{self.rank}_loss/{tag}": loss,
|
||||
f"rank{self.rank}_lr/{tag}": lr,
|
||||
}
|
||||
|
||||
if writer is not None:
|
||||
writer.add_scalar(f"rank{self.rank}_loss/{tag}", loss, self.batch_total)
|
||||
writer.add_scalar(f"rank{self.rank}_lr/{tag}", lr, self.batch_total)
|
||||
for key, var in stats.items():
|
||||
writer.add_scalar(
|
||||
f"stats_rank{self.rank}_{key}/{tag}", var.item(), self.batch_total
|
||||
)
|
||||
description_dict[f"stats_rank{self.rank}_{key}/{tag}"] = var.item()
|
||||
for key, var in speed_stats.items():
|
||||
writer.add_scalar(
|
||||
f"stats_rank{self.rank}_{key}/{tag}", eval(var), self.batch_total
|
||||
)
|
||||
description_dict[f"stats_rank{self.rank}_{key}/{tag}"] = eval(var)
|
||||
if self.use_wandb and wandb is not None:
|
||||
wandb.log(
|
||||
description_dict,
|
||||
setp=self.batch_total,
|
||||
)
|
||||
|
||||
def close(self, writer=None):
|
||||
|
||||
if self.use_ddp or self.use_fsdp:
|
||||
dist.barrier()
|
||||
|
||||
if writer is not None:
|
||||
writer.close()
|
||||
|
||||
if self.use_ddp or self.use_fsdp:
|
||||
torch.distributed.destroy_process_group()
|
||||
997
modules/python/vendors/FunASR/funasr/train_utils/trainer_ds.py
vendored
Normal file
997
modules/python/vendors/FunASR/funasr/train_utils/trainer_ds.py
vendored
Normal file
@@ -0,0 +1,997 @@
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
import torch
|
||||
import logging
|
||||
from tqdm import tqdm
|
||||
from datetime import datetime
|
||||
import torch.distributed as dist
|
||||
from torch.cuda.amp import autocast, GradScaler
|
||||
from contextlib import nullcontext, contextmanager
|
||||
from pathlib import Path
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
from funasr.train_utils.device_funcs import to_device
|
||||
from funasr.train_utils.recursive_op import recursive_average
|
||||
from funasr.train_utils.average_nbest_models import average_checkpoints
|
||||
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
|
||||
import funasr.utils.misc as misc_utils
|
||||
|
||||
try:
|
||||
import wandb
|
||||
except:
|
||||
wandb = None
|
||||
|
||||
|
||||
@contextmanager
|
||||
def maybe_autocast(dtype=None, use_deepspeed=False):
|
||||
if use_deepspeed:
|
||||
with torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False):
|
||||
yield
|
||||
else:
|
||||
if dtype == torch.float16 or dtype == torch.bfloat16:
|
||||
yield
|
||||
# with autocast(enabled=True, dtype=dtype):
|
||||
# yield
|
||||
else:
|
||||
yield
|
||||
|
||||
|
||||
class Trainer:
|
||||
"""
|
||||
A simple trainer class for training a PyTorch model, saving checkpoints at the end of each epoch,
|
||||
and optionally resuming from a saved checkpoint.
|
||||
|
||||
Attributes:
|
||||
max_epoch (int): Maximum number of epochs for training.
|
||||
model (torch.nn.Module): The model to be trained.
|
||||
optim (torch.optim.Optimizer): The optimizer to use for training.
|
||||
scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
|
||||
dataloader_train (torch.utils.data.DataLoader): DataLoader for the training dataset.
|
||||
dataloader_val (torch.utils.data.DataLoader): DataLoader for the validation dataset.
|
||||
output_dir (str): Directory where model checkpoints will be saved.
|
||||
resume (str, optional): Path to a checkpoint to resume training from.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rank=0,
|
||||
local_rank=0,
|
||||
world_size=1,
|
||||
use_ddp: bool = False,
|
||||
use_fsdp: bool = False,
|
||||
use_fp16: bool = False,
|
||||
use_bf16: bool = False,
|
||||
use_deepspeed: bool = False,
|
||||
output_dir: str = "./",
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initializes the Trainer class with the model, optimizer, scheduler, dataloaders, and other settings.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): The model to be trained.
|
||||
optim (torch.optim.Optimizer): The optimizer to use for training.
|
||||
scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
|
||||
dataloader_train (torch.utils.data.DataLoader): The DataLoader for the training dataset.
|
||||
dataloader_val (torch.utils.data.DataLoader): The DataLoader for the validation dataset.
|
||||
**kwargs: Additional keyword arguments:
|
||||
max_epoch (int): The maximum number of epochs for training.
|
||||
output_dir (str): The directory where model checkpoints will be saved. Default is './'.
|
||||
resume (str, optional): The file path to a checkpoint to resume training from.
|
||||
"""
|
||||
self.rank = rank
|
||||
self.local_rank = local_rank
|
||||
self.world_size = world_size
|
||||
self.use_ddp = use_ddp
|
||||
self.use_fsdp = use_fsdp
|
||||
|
||||
self.device = kwargs.get("device", "cuda")
|
||||
|
||||
self.output_dir = output_dir
|
||||
if not os.path.exists(self.output_dir):
|
||||
os.makedirs(self.output_dir, exist_ok=True)
|
||||
self.resume = kwargs.get("resume", True)
|
||||
self.start_epoch = 0
|
||||
self.max_epoch = kwargs.get("max_epoch", 100)
|
||||
|
||||
# self.kwargs = kwargs
|
||||
self.log_interval = kwargs.get("log_interval", 50)
|
||||
self.batch_total = 0
|
||||
self.dtype = torch.float32
|
||||
self.use_fp16 = use_fp16
|
||||
self.use_bf16 = use_bf16
|
||||
if self.use_fp16:
|
||||
self.dtype = torch.float16
|
||||
if self.use_bf16:
|
||||
self.dtype = torch.bfloat16
|
||||
self.save_checkpoint_interval = kwargs.get("save_checkpoint_interval", 5000)
|
||||
self.validate_interval = kwargs.get("validate_interval", 5000)
|
||||
self.keep_nbest_models = kwargs.get("keep_nbest_models", 500)
|
||||
self.avg_keep_nbest_models_type = kwargs.get("avg_keep_nbest_models_type", "acc")
|
||||
self.avg_nbest_model = kwargs.get("avg_nbest_model", 10)
|
||||
self.accum_grad = kwargs.get("accum_grad", 1)
|
||||
self.grad_clip = kwargs.get("grad_clip", 10.0)
|
||||
self.grad_clip_type = kwargs.get("grad_clip_type", 2.0)
|
||||
|
||||
self.train_acc_avg = 0.0
|
||||
self.train_loss_avg = 0.0
|
||||
self.val_acc_avg = 0.0
|
||||
self.val_loss_avg = 0.0
|
||||
self.best_acc_idx = 0
|
||||
self.saved_ckpts = {}
|
||||
self.step_or_epoch = -1
|
||||
self.best_step_or_epoch = ""
|
||||
self.val_acc_step_or_eoch = {}
|
||||
self.val_loss_step_or_eoch = {}
|
||||
|
||||
self.reset_gpu_cache = kwargs.get("reset_gpu_cache", False)
|
||||
self.start_data_split_i = 0
|
||||
self.start_step = 0
|
||||
self.step_in_epoch = 0
|
||||
self.use_wandb = kwargs.get("use_wandb", False)
|
||||
if self.use_wandb:
|
||||
wandb.login(key=kwargs.get("wandb_token"))
|
||||
wandb.init(
|
||||
config=kwargs,
|
||||
project=kwargs.get("wandb_project", "my_project"),
|
||||
entity=kwargs.get("wandb_team", "my_team"),
|
||||
name=kwargs.get("wandb_exp_name", "my_exp"),
|
||||
dir=output_dir,
|
||||
job_type="training",
|
||||
reinit=True,
|
||||
)
|
||||
tensorboard_dir = os.path.join(output_dir, "tensorboard")
|
||||
os.makedirs(tensorboard_dir, exist_ok=True)
|
||||
try:
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
self.writer = SummaryWriter(tensorboard_dir) # if trainer.rank == 0 else None
|
||||
except:
|
||||
self.writer = None
|
||||
|
||||
self.use_deepspeed = use_deepspeed
|
||||
self.deepspeed_config = kwargs.get("deepspeed_config", "")
|
||||
excludes = kwargs.get("excludes", None)
|
||||
if excludes is not None:
|
||||
if isinstance(excludes, str):
|
||||
excludes = excludes.split(",")
|
||||
self.excludes = excludes
|
||||
effective_save_name_excludes = kwargs.get("effective_save_name_excludes", None)
|
||||
if effective_save_name_excludes is not None:
|
||||
if isinstance(effective_save_name_excludes, str):
|
||||
effective_save_name_excludes = effective_save_name_excludes.split(",")
|
||||
self.effective_save_name_excludes = effective_save_name_excludes
|
||||
|
||||
def save_checkpoint(
|
||||
self,
|
||||
epoch,
|
||||
step=None,
|
||||
model=None,
|
||||
optim=None,
|
||||
scheduler=None,
|
||||
scaler=None,
|
||||
step_in_epoch=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Saves a checkpoint containing the model's state, the optimizer's state,
|
||||
and the scheduler's state at the end of the given epoch. This method is
|
||||
intended to be called at the end of each epoch to save the training progress.
|
||||
|
||||
Args:
|
||||
epoch (int): The epoch number at which the checkpoint is being saved.
|
||||
"""
|
||||
if self.use_ddp or self.use_fsdp:
|
||||
dist.barrier()
|
||||
step_in_epoch = None if step is None else step_in_epoch
|
||||
if self.use_deepspeed:
|
||||
|
||||
logging.info(f"Save checkpoint: {epoch}, rank: {self.local_rank}\n")
|
||||
# self.step_or_epoch += 1
|
||||
state = {
|
||||
"epoch": epoch,
|
||||
# "state_dict": model.state_dict(),
|
||||
# "optimizer": optim.state_dict(),
|
||||
# "scheduler": scheduler.state_dict(),
|
||||
"saved_ckpts": self.saved_ckpts,
|
||||
"val_acc_step_or_eoch": self.val_acc_step_or_eoch,
|
||||
"val_loss_step_or_eoch": self.val_loss_step_or_eoch,
|
||||
"best_step_or_epoch": self.best_step_or_epoch,
|
||||
"avg_keep_nbest_models_type": self.avg_keep_nbest_models_type,
|
||||
"step": step,
|
||||
"step_in_epoch": step_in_epoch,
|
||||
"data_split_i": kwargs.get("data_split_i", 0),
|
||||
"data_split_num": kwargs.get("data_split_num", 1),
|
||||
"batch_total": self.batch_total,
|
||||
"train_loss_avg": kwargs.get("train_loss_avg", 0),
|
||||
"train_acc_avg": kwargs.get("train_acc_avg", 0),
|
||||
}
|
||||
step = step_in_epoch
|
||||
if hasattr(model, "module"):
|
||||
state["state_dict"] = model.module.state_dict()
|
||||
|
||||
if scaler:
|
||||
state["scaler_state"] = scaler.state_dict()
|
||||
# Create output directory if it does not exist
|
||||
os.makedirs(self.output_dir, exist_ok=True)
|
||||
if step is None:
|
||||
ckpt_name = f"model.pt.ep{epoch}"
|
||||
else:
|
||||
ckpt_name = f"model.pt.ep{epoch}.{step}"
|
||||
filename = os.path.join(self.output_dir, ckpt_name)
|
||||
|
||||
# torch.save(state, filename)
|
||||
with torch.no_grad():
|
||||
model.save_checkpoint(save_dir=self.output_dir, tag=ckpt_name, client_state=state)
|
||||
logging.info(f"\nCheckpoint saved to {filename}\n")
|
||||
latest = Path(os.path.join(self.output_dir, f"model.pt"))
|
||||
# torch.save(state, latest)
|
||||
with torch.no_grad():
|
||||
model.save_checkpoint(save_dir=self.output_dir, tag=f"model.pt", client_state=state)
|
||||
if self.best_step_or_epoch == "":
|
||||
self.best_step_or_epoch = ckpt_name
|
||||
|
||||
if self.avg_keep_nbest_models_type == "acc":
|
||||
if (
|
||||
self.val_acc_step_or_eoch[ckpt_name]
|
||||
>= self.val_acc_step_or_eoch[self.best_step_or_epoch]
|
||||
):
|
||||
self.best_step_or_epoch = ckpt_name
|
||||
best_ckpt = Path(os.path.join(self.output_dir, f"model.pt.best"))
|
||||
# torch.save(state, best_ckpt)
|
||||
with torch.no_grad():
|
||||
model.save_checkpoint(
|
||||
save_dir=self.output_dir, tag=f"model.pt.best", client_state=state
|
||||
)
|
||||
logging.info(
|
||||
f"Update best acc: {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
|
||||
)
|
||||
else:
|
||||
logging.info(
|
||||
f"No improvement in acc: {self.val_acc_step_or_eoch[ckpt_name]:.4f} < {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
|
||||
)
|
||||
elif self.avg_keep_nbest_models_type == "loss":
|
||||
if (
|
||||
self.val_loss_step_or_eoch[ckpt_name]
|
||||
<= self.val_loss_step_or_eoch[self.best_step_or_epoch]
|
||||
):
|
||||
self.best_step_or_epoch = ckpt_name
|
||||
best_ckpt = Path(os.path.join(self.output_dir, f"model.pt.best"))
|
||||
# torch.save(state, best_ckpt)
|
||||
with torch.no_grad():
|
||||
model.save_checkpoint(
|
||||
save_dir=self.output_dir, tag=f"model.pt.best", client_state=state
|
||||
)
|
||||
logging.info(
|
||||
f"Update best loss: {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
|
||||
)
|
||||
else:
|
||||
logging.info(
|
||||
f"No improvement in loss: {self.val_loss_step_or_eoch[ckpt_name]:.4f} > {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
|
||||
)
|
||||
else:
|
||||
print("Undo")
|
||||
self.saved_ckpts[ckpt_name] = getattr(
|
||||
self, f"val_{self.avg_keep_nbest_models_type}_step_or_eoch"
|
||||
)[ckpt_name]
|
||||
if self.keep_nbest_models > 0:
|
||||
if len(self.saved_ckpts) > self.keep_nbest_models:
|
||||
if self.avg_keep_nbest_models_type == "acc":
|
||||
key = min(self.saved_ckpts, key=self.saved_ckpts.get)
|
||||
else:
|
||||
key = max(self.saved_ckpts, key=self.saved_ckpts.get)
|
||||
if key in self.saved_ckpts:
|
||||
del self.saved_ckpts[key]
|
||||
filename = os.path.join(self.output_dir, key)
|
||||
logging.info(f"Delete: {filename}")
|
||||
if os.path.exists(filename):
|
||||
# os.remove(filename)
|
||||
misc_utils.smart_remove(filename)
|
||||
|
||||
elif self.use_fsdp:
|
||||
pass
|
||||
elif self.rank == 0:
|
||||
logging.info(
|
||||
f"Save checkpoint: {epoch}, rank: {self.rank}, local_rank: {self.local_rank}\n"
|
||||
)
|
||||
# self.step_or_epoch += 1
|
||||
state = {
|
||||
"epoch": epoch,
|
||||
"optimizer": optim.state_dict(),
|
||||
"scheduler": scheduler.state_dict(),
|
||||
"saved_ckpts": self.saved_ckpts,
|
||||
"val_acc_step_or_eoch": self.val_acc_step_or_eoch,
|
||||
"val_loss_step_or_eoch": self.val_loss_step_or_eoch,
|
||||
"best_step_or_epoch": self.best_step_or_epoch,
|
||||
"avg_keep_nbest_models_type": self.avg_keep_nbest_models_type,
|
||||
"step": step,
|
||||
"step_in_epoch": step_in_epoch,
|
||||
"data_split_i": kwargs.get("data_split_i", 0),
|
||||
"data_split_num": kwargs.get("data_split_num", 1),
|
||||
"batch_total": self.batch_total,
|
||||
"train_loss_avg": kwargs.get("train_loss_avg", 0),
|
||||
"train_acc_avg": kwargs.get("train_acc_avg", 0),
|
||||
}
|
||||
step = step_in_epoch
|
||||
if hasattr(model, "module"):
|
||||
state_dict = model.module.state_dict()
|
||||
else:
|
||||
state_dict = model.state_dict()
|
||||
|
||||
if self.effective_save_name_excludes is not None:
|
||||
logging.info(f"effective_save_name_excludes: {self.effective_save_name_excludes}")
|
||||
dst_state_dict = {}
|
||||
for k in state_dict.keys():
|
||||
for k_ex in self.effective_save_name_excludes:
|
||||
k_tmp = k.replace("module.", "")
|
||||
if k.startswith(k_ex):
|
||||
logging.info(f"key: {k} matching: {k_ex}, not save it")
|
||||
break
|
||||
else:
|
||||
dst_state_dict[k] = state_dict[k]
|
||||
state["state_dict"] = dst_state_dict
|
||||
else:
|
||||
state["state_dict"] = state_dict
|
||||
|
||||
if scaler:
|
||||
state["scaler_state"] = scaler.state_dict()
|
||||
# Create output directory if it does not exist
|
||||
os.makedirs(self.output_dir, exist_ok=True)
|
||||
if step is None:
|
||||
ckpt_name = f"model.pt.ep{epoch}"
|
||||
else:
|
||||
ckpt_name = f"model.pt.ep{epoch}.{step}"
|
||||
filename = os.path.join(self.output_dir, ckpt_name)
|
||||
torch.save(state, filename)
|
||||
|
||||
logging.info(f"\nCheckpoint saved to {filename}\n")
|
||||
latest = Path(os.path.join(self.output_dir, f"model.pt"))
|
||||
torch.save(state, latest)
|
||||
if self.best_step_or_epoch == "":
|
||||
self.best_step_or_epoch = ckpt_name
|
||||
|
||||
if self.avg_keep_nbest_models_type == "acc":
|
||||
if (
|
||||
self.val_acc_step_or_eoch[ckpt_name]
|
||||
>= self.val_acc_step_or_eoch[self.best_step_or_epoch]
|
||||
):
|
||||
self.best_step_or_epoch = ckpt_name
|
||||
best_ckpt = Path(os.path.join(self.output_dir, f"model.pt.best"))
|
||||
torch.save(state, best_ckpt)
|
||||
logging.info(
|
||||
f"Update best acc: {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
|
||||
)
|
||||
else:
|
||||
logging.info(
|
||||
f"No improvement in acc: {self.val_acc_step_or_eoch[ckpt_name]:.4f} < {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
|
||||
)
|
||||
elif self.avg_keep_nbest_models_type == "loss":
|
||||
if (
|
||||
self.val_loss_step_or_eoch[ckpt_name]
|
||||
<= self.val_loss_step_or_eoch[self.best_step_or_epoch]
|
||||
):
|
||||
self.best_step_or_epoch = ckpt_name
|
||||
best_ckpt = Path(os.path.join(self.output_dir, f"model.pt.best"))
|
||||
torch.save(state, best_ckpt)
|
||||
logging.info(
|
||||
f"Update best loss: {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
|
||||
)
|
||||
else:
|
||||
logging.info(
|
||||
f"No improvement in loss: {self.val_loss_step_or_eoch[ckpt_name]:.4f} > {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
|
||||
)
|
||||
else:
|
||||
print("Undo")
|
||||
self.saved_ckpts[ckpt_name] = getattr(
|
||||
self, f"val_{self.avg_keep_nbest_models_type}_step_or_eoch"
|
||||
)[ckpt_name]
|
||||
if self.keep_nbest_models > 0:
|
||||
if len(self.saved_ckpts) > self.keep_nbest_models:
|
||||
if self.avg_keep_nbest_models_type == "acc":
|
||||
key = min(self.saved_ckpts, key=self.saved_ckpts.get)
|
||||
else:
|
||||
key = max(self.saved_ckpts, key=self.saved_ckpts.get)
|
||||
if key in self.saved_ckpts:
|
||||
del self.saved_ckpts[key]
|
||||
filename = os.path.join(self.output_dir, key)
|
||||
logging.info(f"Delete: {filename}")
|
||||
if os.path.exists(filename):
|
||||
# os.remove(filename)
|
||||
misc_utils.smart_remove(filename)
|
||||
|
||||
if self.use_ddp or self.use_fsdp:
|
||||
dist.barrier()
|
||||
|
||||
def resume_checkpoint(
|
||||
self,
|
||||
model=None,
|
||||
optim=None,
|
||||
scheduler=None,
|
||||
scaler=None,
|
||||
):
|
||||
"""
|
||||
Resumes training from a checkpoint at the given file path.
|
||||
Loads the model's state, the optimizer's state, and the scheduler's state.
|
||||
|
||||
Args:
|
||||
resume_path (str): The file path to the checkpoint to resume from.
|
||||
"""
|
||||
if self.resume:
|
||||
|
||||
if self.use_deepspeed:
|
||||
ckpt = os.path.join(self.output_dir, "model.pt")
|
||||
if os.path.exists(ckpt):
|
||||
_, checkpoint = model.load_checkpoint(self.output_dir, "model.pt")
|
||||
self.start_epoch = checkpoint["epoch"]
|
||||
self.saved_ckpts = checkpoint["saved_ckpts"]
|
||||
self.val_acc_step_or_eoch = (
|
||||
checkpoint["val_acc_step_or_eoch"]
|
||||
if "val_acc_step_or_eoch" in checkpoint
|
||||
else {}
|
||||
)
|
||||
self.val_loss_step_or_eoch = (
|
||||
checkpoint["val_loss_step_or_eoch"]
|
||||
if "val_loss_step_or_eoch" in checkpoint
|
||||
else {}
|
||||
)
|
||||
self.best_step_or_epoch = (
|
||||
checkpoint["best_step_or_epoch"]
|
||||
if "best_step_or_epoch" in checkpoint
|
||||
else ""
|
||||
)
|
||||
self.start_data_split_i = (
|
||||
checkpoint["data_split_i"] if "data_split_i" in checkpoint else 0
|
||||
)
|
||||
self.batch_total = (
|
||||
checkpoint["batch_total"] if "batch_total" in checkpoint else 0
|
||||
)
|
||||
self.start_step = checkpoint["step"] if "step" in checkpoint else 0
|
||||
self.start_step = 0 if self.start_step is None else self.start_step
|
||||
self.step_in_epoch = (
|
||||
checkpoint["step_in_epoch"] if "step_in_epoch" in checkpoint else 0
|
||||
)
|
||||
self.step_in_epoch = 0 if self.step_in_epoch is None else self.step_in_epoch
|
||||
print(checkpoint["train_acc_avg"])
|
||||
self.train_acc_avg = (
|
||||
checkpoint["train_acc_avg"] if "train_acc_avg" in checkpoint else 0
|
||||
)
|
||||
self.train_loss_avg = (
|
||||
checkpoint["train_loss_avg"] if "train_loss_avg" in checkpoint else 0
|
||||
)
|
||||
model.to(self.device)
|
||||
print(f"Checkpoint loaded successfully from '{ckpt}'")
|
||||
else:
|
||||
print(f"No checkpoint found at '{ckpt}', does not resume status!")
|
||||
else:
|
||||
|
||||
ckpt = os.path.join(self.output_dir, "model.pt")
|
||||
if os.path.isfile(ckpt):
|
||||
checkpoint = torch.load(ckpt, map_location="cpu")
|
||||
self.start_epoch = checkpoint["epoch"]
|
||||
# self.model.load_state_dict(checkpoint['state_dict'])
|
||||
src_state = checkpoint["state_dict"]
|
||||
dst_state = model.state_dict()
|
||||
for k in dst_state.keys():
|
||||
excludes_flag = False
|
||||
if self.excludes is not None:
|
||||
for k_ex in self.excludes:
|
||||
k_tmp = k.replace("module.", "")
|
||||
if k_tmp.startswith(k_ex):
|
||||
logging.info(f"key: {k} matching: {k_ex}, excluded")
|
||||
excludes_flag = True
|
||||
break
|
||||
if excludes_flag:
|
||||
continue
|
||||
if not k.startswith("module.") and "module." + k in src_state.keys():
|
||||
k_ddp = "module." + k
|
||||
elif k.startswith("module.") and "module." + k not in src_state.keys():
|
||||
k_ddp = k.replace("module.", "", 1)
|
||||
else:
|
||||
k_ddp = k
|
||||
if k_ddp in src_state.keys():
|
||||
dst_state[k] = src_state[k_ddp]
|
||||
else:
|
||||
print(f"Miss key in ckpt: model: {k}, ckpt: {k_ddp}")
|
||||
|
||||
model.load_state_dict(dst_state)
|
||||
optim.load_state_dict(checkpoint["optimizer"])
|
||||
scheduler.load_state_dict(checkpoint["scheduler"])
|
||||
if scaler is not None and "scaler_state" in checkpoint:
|
||||
scaler.load_state_dict(checkpoint["scaler_state"])
|
||||
|
||||
self.saved_ckpts = checkpoint["saved_ckpts"]
|
||||
self.val_acc_step_or_eoch = (
|
||||
checkpoint["val_acc_step_or_eoch"]
|
||||
if "val_acc_step_or_eoch" in checkpoint
|
||||
else {}
|
||||
)
|
||||
self.val_loss_step_or_eoch = (
|
||||
checkpoint["val_loss_step_or_eoch"]
|
||||
if "val_loss_step_or_eoch" in checkpoint
|
||||
else {}
|
||||
)
|
||||
self.best_step_or_epoch = (
|
||||
checkpoint["best_step_or_epoch"]
|
||||
if "best_step_or_epoch" in checkpoint
|
||||
else ""
|
||||
)
|
||||
self.start_data_split_i = (
|
||||
checkpoint["data_split_i"] if "data_split_i" in checkpoint else 0
|
||||
)
|
||||
self.batch_total = (
|
||||
checkpoint["batch_total"] if "batch_total" in checkpoint else 0
|
||||
)
|
||||
self.start_step = checkpoint["step"] if "step" in checkpoint else 0
|
||||
self.start_step = 0 if self.start_step is None else self.start_step
|
||||
self.step_in_epoch = (
|
||||
checkpoint["step_in_epoch"] if "step_in_epoch" in checkpoint else 0
|
||||
)
|
||||
self.step_in_epoch = 0 if self.step_in_epoch is None else self.step_in_epoch
|
||||
print(checkpoint["train_acc_avg"])
|
||||
self.train_acc_avg = (
|
||||
checkpoint["train_acc_avg"] if "train_acc_avg" in checkpoint else 0
|
||||
)
|
||||
self.train_loss_avg = (
|
||||
checkpoint["train_loss_avg"] if "train_loss_avg" in checkpoint else 0
|
||||
)
|
||||
model.to(self.device)
|
||||
print(f"Checkpoint loaded successfully from '{ckpt}'")
|
||||
else:
|
||||
print(f"No checkpoint found at '{ckpt}', does not resume status!")
|
||||
|
||||
if self.use_ddp or self.use_fsdp:
|
||||
dist.barrier()
|
||||
|
||||
def train_epoch(
|
||||
self,
|
||||
model=None,
|
||||
optim=None,
|
||||
scheduler=None,
|
||||
scaler=None,
|
||||
dataloader_train=None,
|
||||
dataloader_val=None,
|
||||
epoch=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Defines the training process for a single epoch with gradient accumulation.
|
||||
Args:
|
||||
epoch (int): The current epoch number.
|
||||
"""
|
||||
if self.use_ddp or self.use_fsdp or self.use_deepspeed:
|
||||
dist.barrier()
|
||||
logging.info(f"Train epoch: {epoch}, rank: {self.rank}\n")
|
||||
model.train()
|
||||
|
||||
# Set the number of steps for gradient accumulation
|
||||
accum_grad = self.accum_grad
|
||||
# Initialize the gradient accumulation
|
||||
optim.zero_grad()
|
||||
speed_stats = {}
|
||||
|
||||
iterator_stop = torch.tensor(0).to(self.device)
|
||||
|
||||
dataloader_train.batch_sampler.set_epoch(epoch)
|
||||
time_beg = time.perf_counter()
|
||||
time5 = time_beg
|
||||
for batch_idx, batch in enumerate(dataloader_train):
|
||||
self.batch_total += 1
|
||||
self.step_in_epoch += 1
|
||||
loss_dict = {
|
||||
"speed_stats": {},
|
||||
"epoch": epoch,
|
||||
"batch_idx": batch_idx,
|
||||
"data_split_i": kwargs.get("data_split_i", 0),
|
||||
"data_split_num": kwargs.get("data_split_num", 1),
|
||||
"log_step": batch_idx + kwargs.get("start_step", 0),
|
||||
"batch_total": self.batch_total,
|
||||
"step_in_epoch": self.step_in_epoch,
|
||||
}
|
||||
|
||||
time1 = time.perf_counter()
|
||||
loss_dict["speed_stats"]["data_load"] = f"{time1-time_beg:0.3f}"
|
||||
|
||||
batch = to_device(batch, self.device)
|
||||
|
||||
my_context = nullcontext
|
||||
if self.use_ddp or self.use_fsdp:
|
||||
my_context = model.no_sync if batch_idx % accum_grad != 0 else my_context
|
||||
with my_context():
|
||||
time2 = time.perf_counter()
|
||||
|
||||
self.forward_step(model, batch, loss_dict=loss_dict)
|
||||
|
||||
time3 = time.perf_counter()
|
||||
loss_dict["speed_stats"]["forward_time"] = f"{time3 - time2:0.3f}"
|
||||
self.backward_step(model, scaler, loss_dict=loss_dict)
|
||||
|
||||
time4 = time.perf_counter()
|
||||
loss_dict["speed_stats"]["backward_time"] = f"{time4 - time3:0.3f}"
|
||||
|
||||
self.update_step(model, optim, scheduler, scaler, loss_dict=loss_dict)
|
||||
total_time = f"{(time.perf_counter() - time5):0.3f}"
|
||||
time5 = time.perf_counter()
|
||||
|
||||
loss_dict["speed_stats"]["optim_time"] = f"{time5 - time4:0.3f}"
|
||||
|
||||
loss_dict["speed_stats"]["total_time"] = total_time
|
||||
|
||||
loss_dict["lr"] = scheduler.get_last_lr()[0]
|
||||
loss_dict["batch_num_epoch"] = len(dataloader_train)
|
||||
|
||||
self.train_loss_avg = (
|
||||
self.train_loss_avg * batch_idx + loss_dict["loss"].detach().cpu().item()
|
||||
) / (batch_idx + 1)
|
||||
if "acc" in loss_dict["stats"]:
|
||||
self.train_acc_avg = (
|
||||
self.train_acc_avg * batch_idx + loss_dict["stats"]["acc"].detach().cpu().item()
|
||||
) / (batch_idx + 1)
|
||||
|
||||
self.log(loss_dict, tag="train")
|
||||
|
||||
if self.step_in_epoch % self.validate_interval == 0:
|
||||
self.validate_epoch(
|
||||
model=model,
|
||||
dataloader_val=dataloader_val,
|
||||
epoch=epoch,
|
||||
writer=self.writer,
|
||||
step=batch_idx + 1,
|
||||
step_in_epoch=self.step_in_epoch,
|
||||
)
|
||||
|
||||
if self.step_in_epoch % self.save_checkpoint_interval == 0:
|
||||
self.save_checkpoint(
|
||||
epoch,
|
||||
model=model,
|
||||
optim=optim,
|
||||
scheduler=scheduler,
|
||||
scaler=scaler,
|
||||
step=batch_idx + 1,
|
||||
step_in_epoch=self.step_in_epoch,
|
||||
data_split_i=kwargs.get("data_split_i", 0),
|
||||
data_split_num=kwargs.get("data_split_num", 1),
|
||||
train_loss_avg=self.train_loss_avg,
|
||||
train_acc_avg=self.train_acc_avg,
|
||||
)
|
||||
|
||||
time_beg = time.perf_counter()
|
||||
|
||||
if self.use_ddp or self.use_fsdp or self.use_deepspeed:
|
||||
train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(self.device)
|
||||
train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(self.device)
|
||||
dist.all_reduce(train_loss_avg, op=dist.ReduceOp.SUM)
|
||||
dist.all_reduce(train_acc_avg, op=dist.ReduceOp.SUM)
|
||||
self.train_loss_avg = train_loss_avg.detach().cpu().item() / self.world_size
|
||||
self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
|
||||
|
||||
def forward_step(self, model, batch, loss_dict={}):
|
||||
with maybe_autocast(dtype=self.dtype, use_deepspeed=self.use_deepspeed):
|
||||
retval = model(**batch)
|
||||
|
||||
loss, stats, weight = retval
|
||||
stats = {k: v for k, v in stats.items() if v is not None}
|
||||
|
||||
loss_dict["loss"] = loss
|
||||
loss_dict["stats"] = stats
|
||||
loss_dict["weight"] = weight
|
||||
|
||||
def backward_step(self, model, scaler, loss_dict={}):
|
||||
loss = loss_dict["loss"]
|
||||
|
||||
if self.use_deepspeed:
|
||||
scaled_loss = model.backward(loss)
|
||||
else:
|
||||
loss = loss / self.accum_grad
|
||||
if self.use_fp16 or self.use_bf16:
|
||||
scaler.scale(loss).backward()
|
||||
else:
|
||||
loss.backward()
|
||||
|
||||
def update_step(self, model, optim, scheduler, scaler, loss_dict=None):
|
||||
batch_idx = loss_dict["batch_idx"]
|
||||
if self.use_deepspeed:
|
||||
model.step()
|
||||
else:
|
||||
if (batch_idx + 1) % self.accum_grad == 0:
|
||||
# Perform gradient clipping if it is set
|
||||
if self.grad_clip > 0:
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
model.parameters(),
|
||||
max_norm=self.grad_clip,
|
||||
norm_type=self.grad_clip_type,
|
||||
)
|
||||
if not torch.isfinite(grad_norm):
|
||||
logging.warning(
|
||||
f"The grad norm is {grad_norm}. Skipping updating the model."
|
||||
)
|
||||
optim.zero_grad() # Reset gradients
|
||||
return
|
||||
|
||||
# Execute an optimization step (update model parameters)
|
||||
if self.use_ddp or self.use_fsdp:
|
||||
dist.barrier()
|
||||
if self.use_fp16 or self.use_bf16:
|
||||
scaler.step(optim)
|
||||
scaler.update()
|
||||
else:
|
||||
optim.step()
|
||||
scheduler.step()
|
||||
# Clear gradients for the next accumulation stage
|
||||
optim.zero_grad(set_to_none=True)
|
||||
|
||||
def validate_epoch(
|
||||
self,
|
||||
model=None,
|
||||
dataloader_val=None,
|
||||
epoch=None,
|
||||
writer=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Defines the validation process for a single epoch.
|
||||
Should be implemented with the actual model validation steps.
|
||||
|
||||
Args:
|
||||
epoch (int): The current epoch number.
|
||||
"""
|
||||
if self.use_ddp or self.use_fsdp or self.use_deepspeed:
|
||||
dist.barrier()
|
||||
logging.info(f"Validate epoch: {epoch}, rank: {self.rank}\n")
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
|
||||
speed_stats = {}
|
||||
time_beg = time.perf_counter()
|
||||
time5 = time_beg
|
||||
|
||||
dataloader_val.batch_sampler.set_epoch(epoch)
|
||||
for batch_idx, batch in enumerate(dataloader_val):
|
||||
|
||||
loss_dict = {
|
||||
"speed_stats": {},
|
||||
"epoch": epoch,
|
||||
"batch_idx": batch_idx,
|
||||
"data_split_i": kwargs.get("data_split_i", 0),
|
||||
"data_split_num": kwargs.get("data_split_num", 1),
|
||||
"log_step": batch_idx + kwargs.get("start_step", 0),
|
||||
"batch_total": batch_idx + 1,
|
||||
"step_in_epoch": batch_idx + 1,
|
||||
"lr": 0.0,
|
||||
}
|
||||
|
||||
time1 = time.perf_counter()
|
||||
loss_dict["speed_stats"]["data_load"] = f"{time1 - time_beg:0.3f}"
|
||||
|
||||
batch = to_device(batch, self.device)
|
||||
|
||||
time2 = time.perf_counter()
|
||||
|
||||
self.forward_step(model, batch, loss_dict=loss_dict)
|
||||
|
||||
time3 = time.perf_counter()
|
||||
loss_dict["speed_stats"]["forward_time"] = f"{time3 - time2:0.3f}"
|
||||
|
||||
total_time = f"{(time.perf_counter() - time5):0.3f}"
|
||||
time5 = time.perf_counter()
|
||||
|
||||
loss_dict["speed_stats"]["total_time"] = total_time
|
||||
|
||||
loss_dict["batch_num_epoch"] = len(dataloader_val)
|
||||
|
||||
self.log(loss_dict, tag="val")
|
||||
time_beg = time.perf_counter()
|
||||
self.val_loss_avg = (
|
||||
self.val_loss_avg * batch_idx + loss_dict["loss"].detach().cpu().item()
|
||||
) / (batch_idx + 1)
|
||||
if "acc" in loss_dict["stats"]:
|
||||
self.val_acc_avg = (
|
||||
self.val_acc_avg * batch_idx
|
||||
+ loss_dict["stats"]["acc"].detach().cpu().item()
|
||||
) / (batch_idx + 1)
|
||||
|
||||
if self.use_ddp or self.use_fsdp or self.use_deepspeed:
|
||||
val_loss_avg = torch.tensor(self.val_loss_avg, dtype=torch.float32).to(self.device)
|
||||
val_acc_avg = torch.tensor(self.val_acc_avg, dtype=torch.float32).to(self.device)
|
||||
dist.all_reduce(val_loss_avg, op=dist.ReduceOp.SUM)
|
||||
dist.all_reduce(val_acc_avg, op=dist.ReduceOp.SUM)
|
||||
self.val_loss_avg = val_loss_avg.detach().cpu().item() / self.world_size
|
||||
self.val_acc_avg = val_acc_avg.detach().cpu().item() / self.world_size
|
||||
|
||||
if kwargs.get("step_in_epoch", None) is None:
|
||||
ckpt_name = f"model.pt.ep{epoch}"
|
||||
else:
|
||||
ckpt_name = f'model.pt.ep{epoch}.{kwargs.get("step_in_epoch")}'
|
||||
self.val_acc_step_or_eoch[ckpt_name] = self.val_acc_avg
|
||||
self.val_loss_step_or_eoch[ckpt_name] = self.val_loss_avg
|
||||
|
||||
if self.use_ddp or self.use_fsdp or self.use_deepspeed:
|
||||
dist.barrier()
|
||||
|
||||
model.train()
|
||||
|
||||
def log(
|
||||
self,
|
||||
loss_dict: dict = None,
|
||||
tag="train",
|
||||
**kwargs,
|
||||
):
|
||||
loss = loss_dict["loss"].detach().cpu().item()
|
||||
epoch = loss_dict["epoch"]
|
||||
batch_idx = loss_dict["batch_idx"]
|
||||
step_in_epoch = loss_dict["step_in_epoch"]
|
||||
batch_total = loss_dict["batch_total"]
|
||||
batch_num_epoch = loss_dict["batch_num_epoch"]
|
||||
lr = loss_dict["lr"]
|
||||
|
||||
speed_stats = loss_dict["speed_stats"]
|
||||
stats = loss_dict["stats"]
|
||||
data_split_i = loss_dict["data_split_i"]
|
||||
data_split_num = loss_dict["data_split_num"]
|
||||
log_step = loss_dict.get("log_step", None)
|
||||
|
||||
if (batch_idx + 1) % self.log_interval == 0:
|
||||
batch_idx = log_step if log_step is not None else batch_idx
|
||||
gpu_info = (
|
||||
"GPU, memory: usage: {:.3f} GB, "
|
||||
"peak: {:.3f} GB, "
|
||||
"cache: {:.3f} GB, "
|
||||
"cache_peak: {:.3f} GB".format(
|
||||
torch.cuda.memory_allocated() / 1024 / 1024 / 1024,
|
||||
torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024,
|
||||
torch.cuda.memory_reserved() / 1024 / 1024 / 1024,
|
||||
torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024,
|
||||
)
|
||||
)
|
||||
|
||||
loss_avg_epoch = getattr(self, f"{tag}_loss_avg")
|
||||
acc_avg_epoch = getattr(self, f"{tag}_acc_avg")
|
||||
description = (
|
||||
f"{tag}, "
|
||||
f"rank: {self.rank}, "
|
||||
f"epoch: {epoch}/{self.max_epoch}, "
|
||||
f"data_slice: {data_split_i}/{data_split_num}, "
|
||||
f"step_in_slice: {batch_idx + 1}/{batch_num_epoch}, step_in_epoch: {step_in_epoch}, total step: {batch_total}, "
|
||||
f"(loss_avg_rank: {loss:.3f}), "
|
||||
f"(loss_avg_slice: {loss_avg_epoch:.3f}), "
|
||||
f"(ppl_avg_slice: {math.exp(loss_avg_epoch):.3e}), "
|
||||
f"(acc_avg_slice: {acc_avg_epoch:.3f}), "
|
||||
f"(lr: {lr:.3e}), "
|
||||
f"{[(k, round(v.detach().cpu().item(), 3)) for k, v in stats.items()]}, "
|
||||
f"{speed_stats}, "
|
||||
f"{gpu_info}"
|
||||
)
|
||||
logging.info(description)
|
||||
|
||||
description_dict = {
|
||||
f"rank{self.rank}_loss/{tag}": loss,
|
||||
f"rank{self.rank}_lr/{tag}": lr,
|
||||
}
|
||||
|
||||
writer = self.writer
|
||||
if writer is not None:
|
||||
writer.add_scalar(f"rank{self.rank}_loss/{tag}", loss, batch_total)
|
||||
writer.add_scalar(f"rank{self.rank}_lr/{tag}", lr, batch_total)
|
||||
for key, var in stats.items():
|
||||
writer.add_scalar(f"stats_rank{self.rank}_{key}/{tag}", var.item(), batch_total)
|
||||
description_dict[f"stats_rank{self.rank}_{key}/{tag}"] = var.item()
|
||||
for key, var in speed_stats.items():
|
||||
writer.add_scalar(f"stats_rank{self.rank}_{key}/{tag}", eval(var), batch_total)
|
||||
description_dict[f"stats_rank{self.rank}_{key}/{tag}"] = eval(var)
|
||||
if self.use_wandb and wandb is not None:
|
||||
wandb.log(
|
||||
description_dict,
|
||||
setp=batch_total,
|
||||
)
|
||||
|
||||
def close(self, writer=None):
|
||||
|
||||
if self.use_ddp or self.use_fsdp:
|
||||
dist.barrier()
|
||||
|
||||
if writer is not None:
|
||||
writer.close()
|
||||
|
||||
if self.use_ddp or self.use_fsdp:
|
||||
torch.distributed.destroy_process_group()
|
||||
|
||||
def warp_model(self, model, **kwargs):
|
||||
|
||||
if self.use_deepspeed:
|
||||
from deepspeed.runtime.zero.stage_1_and_2 import (
|
||||
estimate_zero2_model_states_mem_needs_all_live,
|
||||
)
|
||||
from deepspeed.runtime.zero.stage3 import estimate_zero3_model_states_mem_needs_all_live
|
||||
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
|
||||
|
||||
local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
|
||||
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
|
||||
# NOTE(xcsong): look in detail how the memory estimator API works:
|
||||
# https://deepspeed.readthedocs.io/en/latest/memory.html#discussion
|
||||
if int(os.environ.get("RANK", 0)) == 0:
|
||||
logging.info("Estimating model states memory needs (zero2)...")
|
||||
estimate_zero2_model_states_mem_needs_all_live(
|
||||
model,
|
||||
num_gpus_per_node=local_world_size,
|
||||
num_nodes=world_size // local_world_size,
|
||||
)
|
||||
logging.info("Estimating model states memory needs (zero3)...")
|
||||
estimate_zero3_model_states_mem_needs_all_live(
|
||||
model,
|
||||
num_gpus_per_node=local_world_size,
|
||||
num_nodes=world_size // local_world_size,
|
||||
)
|
||||
device = None # Init device later
|
||||
pass # Init DeepSpeed later
|
||||
|
||||
elif self.use_ddp:
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
model = model.cuda(local_rank)
|
||||
model = DDP(
|
||||
model,
|
||||
device_ids=[local_rank],
|
||||
find_unused_parameters=kwargs.get("train_conf", {}).get(
|
||||
"find_unused_parameters", False
|
||||
),
|
||||
)
|
||||
|
||||
else:
|
||||
model = model.to(device=kwargs.get("device", "cuda"))
|
||||
|
||||
return model
|
||||
|
||||
def warp_optim_scheduler(self, model, **kwargs):
|
||||
from funasr.optimizers import optim_classes
|
||||
from funasr.schedulers import scheduler_classes
|
||||
from omegaconf import OmegaConf, DictConfig
|
||||
import json
|
||||
|
||||
# optim
|
||||
logging.info("Build optim")
|
||||
optim = kwargs.get("optim", "adam")
|
||||
assert optim in optim_classes
|
||||
optim_class = optim_classes.get(optim)
|
||||
optim = optim_class(model.parameters(), **kwargs.get("optim_conf"))
|
||||
|
||||
# scheduler
|
||||
logging.info("Build scheduler")
|
||||
scheduler = kwargs.get("scheduler", "warmuplr")
|
||||
assert scheduler in scheduler_classes
|
||||
scheduler_class = scheduler_classes.get(scheduler)
|
||||
scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))
|
||||
|
||||
if self.use_deepspeed:
|
||||
import deepspeed
|
||||
|
||||
args = OmegaConf.create({"deepspeed_config": self.deepspeed_config})
|
||||
with open(self.deepspeed_config, "r") as fin:
|
||||
ds_configs = json.load(fin)
|
||||
|
||||
if "bf16" in ds_configs and ds_configs["bf16"]["enabled"]:
|
||||
self.dtype = torch.bfloat16
|
||||
|
||||
if "fp16" in ds_configs and ds_configs["fp16"]["enabled"]:
|
||||
self.dtype = torch.float16
|
||||
if "optimizer" in ds_configs:
|
||||
# NOTE(xcsong): Disable custom optimizer if it is set in ds_config,
|
||||
# extremely useful when enable cpu_offload, DeepspeedCpuAdam
|
||||
# could be 4~5x faster than torch native adam
|
||||
optim = None
|
||||
if "scheduler" in ds_configs:
|
||||
scheduler = None
|
||||
else:
|
||||
|
||||
def scheduler(opt):
|
||||
return scheduler_class(opt, **kwargs.get("scheduler_conf"))
|
||||
|
||||
model, optimizer, _, scheduler = deepspeed.initialize(
|
||||
args=args,
|
||||
model=model,
|
||||
optimizer=optim,
|
||||
lr_scheduler=scheduler,
|
||||
model_parameters=model.parameters(),
|
||||
)
|
||||
|
||||
return model, optim, scheduler
|
||||
Reference in New Issue
Block a user