Sync from bytedesk-private: update

This commit is contained in:
jack ning
2024-12-14 10:43:18 +08:00
parent 476eebb101
commit 5e082909e4
3421 changed files with 812709 additions and 0 deletions

View File

@@ -0,0 +1,31 @@
import torch
def add_gradient_noise(
model: torch.nn.Module,
iteration: int,
duration: float = 100,
eta: float = 1.0,
scale_factor: float = 0.55,
):
"""Adds noise from a standard normal distribution to the gradients.
The standard deviation (`sigma`) is controlled
by the three hyper-parameters below.
`sigma` goes to zero (no noise) with more iterations.
Args:
model: Model.
iteration: Number of iterations.
duration: {100, 1000}: Number of durations to control
the interval of the `sigma` change.
eta: {0.01, 0.3, 1.0}: The magnitude of `sigma`.
scale_factor: {0.55}: The scale of `sigma`.
"""
interval = (iteration // duration) + 1
sigma = eta / interval**scale_factor
for param in model.parameters():
if param.grad is not None:
_shape = param.grad.size()
noise = sigma * torch.randn(_shape).to(param.device)
param.grad += noise

View File

@@ -0,0 +1,97 @@
import logging
from pathlib import Path
from typing import Optional
from typing import Sequence
from typing import Union
import warnings
import os
from io import BytesIO
import torch
from typing import Collection
import os
import torch
import re
from collections import OrderedDict
from functools import cmp_to_key
def _get_checkpoint_paths(output_dir: str, last_n: int = 5, use_deepspeed=False, **kwargs):
"""
Get the paths of the last 'last_n' checkpoints by parsing filenames
in the output directory.
"""
try:
if not use_deepspeed:
checkpoint = torch.load(os.path.join(output_dir, "model.pt"), map_location="cpu")
else:
checkpoint = torch.load(
os.path.join(output_dir, "model.pt", "mp_rank_00_model_states.pt"),
map_location="cpu",
)
avg_keep_nbest_models_type = checkpoint["avg_keep_nbest_models_type"]
val_step_or_epoch = checkpoint[f"val_{avg_keep_nbest_models_type}_step_or_epoch"]
sorted_items = sorted(val_step_or_epoch.items(), key=lambda x: x[1], reverse=True)
sorted_items = (
sorted_items[:last_n] if avg_keep_nbest_models_type == "acc" else sorted_items[-last_n:]
)
checkpoint_paths = []
for key, value in sorted_items[:last_n]:
if not use_deepspeed:
ckpt = os.path.join(output_dir, key)
else:
ckpt = os.path.join(output_dir, key, "mp_rank_00_model_states.pt")
checkpoint_paths.append(ckpt)
except:
print(f"{checkpoint} does not exist, avg the lastet checkpoint.")
# List all files in the output directory
files = os.listdir(output_dir)
# Filter out checkpoint files and extract epoch numbers
checkpoint_files = [f for f in files if f.startswith("model.pt.e")]
# Sort files by epoch number in descending order
checkpoint_files.sort(key=lambda x: int(re.search(r"(\d+)", x).group()), reverse=True)
# Get the last 'last_n' checkpoint paths
checkpoint_paths = [os.path.join(output_dir, f) for f in checkpoint_files[:last_n]]
return checkpoint_paths
@torch.no_grad()
def average_checkpoints(output_dir: str, last_n: int = 5, **kwargs):
"""
Average the last 'last_n' checkpoints' model state_dicts.
If a tensor is of type torch.int, perform sum instead of average.
"""
checkpoint_paths = _get_checkpoint_paths(output_dir, last_n, **kwargs)
print(f"average_checkpoints: {checkpoint_paths}")
state_dicts = []
# Load state_dicts from checkpoints
for path in checkpoint_paths:
if os.path.isfile(path):
state_dicts.append(torch.load(path, map_location="cpu")["state_dict"])
else:
print(f"Checkpoint file {path} not found.")
# Check if we have any state_dicts to average
if len(state_dicts) < 1:
print("No checkpoints found for averaging.")
return
# Average or sum weights
avg_state_dict = OrderedDict()
for key in state_dicts[0].keys():
tensors = [state_dict[key].cpu() for state_dict in state_dicts]
# Check the type of the tensor
if str(tensors[0].dtype).startswith("torch.int"):
# Perform sum for integer tensors
summed_tensor = sum(tensors)
avg_state_dict[key] = summed_tensor
else:
# Perform average for other types of tensors
stacked_tensors = torch.stack(tensors)
avg_state_dict[key] = torch.mean(stacked_tensors, dim=0)
checkpoint_outpath = os.path.join(output_dir, f"model.pt.avg{last_n}")
torch.save({"state_dict": avg_state_dict}, checkpoint_outpath)
return checkpoint_outpath

View File

@@ -0,0 +1,64 @@
import dataclasses
import warnings
import numpy as np
import torch
def to_device(data, device=None, dtype=None, non_blocking=False, copy=False):
"""Change the device of object recursively"""
if isinstance(data, dict):
return {k: to_device(v, device, dtype, non_blocking, copy) for k, v in data.items()}
elif dataclasses.is_dataclass(data) and not isinstance(data, type):
return type(data)(
*[to_device(v, device, dtype, non_blocking, copy) for v in dataclasses.astuple(data)]
)
# maybe namedtuple. I don't know the correct way to judge namedtuple.
elif isinstance(data, tuple) and type(data) is not tuple:
return type(data)(*[to_device(o, device, dtype, non_blocking, copy) for o in data])
elif isinstance(data, (list, tuple)):
return type(data)(to_device(v, device, dtype, non_blocking, copy) for v in data)
elif isinstance(data, np.ndarray):
return to_device(torch.from_numpy(data), device, dtype, non_blocking, copy)
elif isinstance(data, torch.Tensor):
return data.to(device, dtype, non_blocking, copy)
else:
return data
def force_gatherable(data, device):
"""Change object to gatherable in torch.nn.DataParallel recursively
The difference from to_device() is changing to torch.Tensor if float or int
value is found.
The restriction to the returned value in DataParallel:
The object must be
- torch.cuda.Tensor
- 1 or more dimension. 0-dimension-tensor sends warning.
or a list, tuple, dict.
"""
if isinstance(data, dict):
return {k: force_gatherable(v, device) for k, v in data.items()}
# DataParallel can't handle NamedTuple well
elif isinstance(data, tuple) and type(data) is not tuple:
return type(data)(*[force_gatherable(o, device) for o in data])
elif isinstance(data, (list, tuple, set)):
return type(data)(force_gatherable(v, device) for v in data)
elif isinstance(data, np.ndarray):
return force_gatherable(torch.from_numpy(data), device)
elif isinstance(data, torch.Tensor):
if data.dim() == 0:
# To 1-dim array
data = data[None]
return data.to(device)
elif isinstance(data, float):
return torch.tensor([data], dtype=torch.float, device=device)
elif isinstance(data, int):
return torch.tensor([data], dtype=torch.long, device=device)
elif data is None:
return None
else:
warnings.warn(f"{type(data)} may not be gatherable by DataParallel")
return data

View File

@@ -0,0 +1,31 @@
import torch
class ForwardAdaptor(torch.nn.Module):
"""Wrapped module to parallelize specified method
torch.nn.DataParallel parallelizes only "forward()"
and, maybe, the method having the other name can't be applied
except for wrapping the module just like this class.
Examples:
>>> class A(torch.nn.Module):
... def foo(self, x):
... ...
>>> model = A()
>>> model = ForwardAdaptor(model, "foo")
>>> model = torch.nn.DataParallel(model, device_ids=[0, 1])
>>> x = torch.randn(2, 10)
>>> model(x)
"""
def __init__(self, module: torch.nn.Module, name: str):
super().__init__()
self.module = module
self.name = name
if not hasattr(module, name):
raise ValueError(f"{module} doesn't have {name}")
def forward(self, *args, **kwargs):
func = getattr(self.module, self.name)
return func(*args, **kwargs)

View File

@@ -0,0 +1,55 @@
#!/usr/bin/env python3
"""Initialize modules for espnet2 neural networks."""
import math
import torch
def initialize(model: torch.nn.Module, init: str):
"""Initialize weights of a neural network module.
Parameters are initialized using the given method or distribution.
Custom initialization routines can be implemented into submodules
as function `espnet_initialization_fn` within the custom module.
Args:
model: Target.
init: Method of initialization.
"""
# weight init
for p in model.parameters():
if p.dim() > 1:
if init == "xavier_uniform":
torch.nn.init.xavier_uniform_(p.data)
elif init == "xavier_normal":
torch.nn.init.xavier_normal_(p.data)
elif init == "kaiming_uniform":
torch.nn.init.kaiming_uniform_(p.data, nonlinearity="relu")
elif init == "kaiming_normal":
torch.nn.init.kaiming_normal_(p.data, nonlinearity="relu")
else:
raise ValueError("Unknown initialization: " + init)
# bias init
for p in model.parameters():
if p.dim() == 1:
p.data.zero_()
# reset some modules with default init
for m in model.modules():
if isinstance(m, (torch.nn.Embedding, torch.nn.LayerNorm, torch.nn.GroupNorm)):
m.reset_parameters()
if hasattr(m, "espnet_initialization_fn"):
m.espnet_initialization_fn()
# TODO(xkc): Hacking s3prl_frontend and wav2vec2encoder initialization
if getattr(model, "encoder", None) and getattr(
model.encoder, "reload_pretrained_parameters", None
):
model.encoder.reload_pretrained_parameters()
if getattr(model, "frontend", None) and getattr(
model.frontend, "reload_pretrained_parameters", None
):
model.frontend.reload_pretrained_parameters()

View File

@@ -0,0 +1,103 @@
from typing import Any
from typing import Dict
from typing import Union
from io import BytesIO
import logging
import torch
import torch.nn
import torch.optim
import pdb
import copy
def load_pretrained_model(
path: str,
model: torch.nn.Module,
ignore_init_mismatch: bool = True,
map_location: str = "cpu",
oss_bucket=None,
scope_map=[],
excludes=None,
**kwargs,
):
"""Load a model state and set it to the model.
Args:
init_param: <file_path>:<src_key>:<dst_key>:<exclude_Keys>
Examples:
"""
obj = model
dst_state = obj.state_dict()
logging.info(f"ckpt: {path}")
if oss_bucket is None:
ori_state = torch.load(path, map_location=map_location)
else:
buffer = BytesIO(oss_bucket.get_object(path).read())
ori_state = torch.load(buffer, map_location=map_location)
src_state = copy.deepcopy(ori_state)
src_state = src_state["state_dict"] if "state_dict" in src_state else src_state
src_state = src_state["model_state_dict"] if "model_state_dict" in src_state else src_state
src_state = src_state["model"] if "model" in src_state else src_state
if isinstance(scope_map, str):
scope_map = scope_map.split(",")
scope_map += ["module.", "None"]
logging.info(f"scope_map: {scope_map}")
if excludes is not None:
if isinstance(excludes, str):
excludes = excludes.split(",")
logging.info(f"excludes: {excludes}")
for k in dst_state.keys():
excludes_flag = False
if excludes is not None:
for k_ex in excludes:
if k.startswith(k_ex):
logging.info(f"key: {k} matching: {k_ex}, excluded")
excludes_flag = True
break
if excludes_flag:
continue
k_src = k
if scope_map is not None:
src_prefix = ""
dst_prefix = ""
for i in range(0, len(scope_map), 2):
src_prefix = scope_map[i] if scope_map[i].lower() != "none" else ""
dst_prefix = scope_map[i + 1] if scope_map[i + 1].lower() != "none" else ""
if dst_prefix == "" and (src_prefix + k) in src_state.keys():
k_src = src_prefix + k
if not k_src.startswith("module."):
logging.info(f"init param, map: {k} from {k_src} in ckpt")
elif (
k.startswith(dst_prefix)
and k.replace(dst_prefix, src_prefix, 1) in src_state.keys()
):
k_src = k.replace(dst_prefix, src_prefix, 1)
if not k_src.startswith("module."):
logging.info(f"init param, map: {k} from {k_src} in ckpt")
if k_src in src_state.keys():
if ignore_init_mismatch and dst_state[k].shape != src_state[k_src].shape:
logging.info(
f"ignore_init_mismatch:{ignore_init_mismatch}, dst: {k, dst_state[k].shape}, src: {k_src, src_state[k_src].shape}"
)
else:
dst_state[k] = src_state[k_src]
else:
print(f"Warning, miss key in ckpt: {k}, {path}")
flag = obj.load_state_dict(dst_state, strict=True)
logging.info(f"Loading ckpt: {path}, status: {flag}")

View File

@@ -0,0 +1,72 @@
import numpy as np
import torch
def get_human_readable_count(number: int) -> str:
"""Return human_readable_count
Originated from:
https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/core/memory.py
Abbreviates an integer number with K, M, B, T for thousands, millions,
billions and trillions, respectively.
Examples:
>>> get_human_readable_count(123)
'123 '
>>> get_human_readable_count(1234) # (one thousand)
'1 K'
>>> get_human_readable_count(2e6) # (two million)
'2 M'
>>> get_human_readable_count(3e9) # (three billion)
'3 B'
>>> get_human_readable_count(4e12) # (four trillion)
'4 T'
>>> get_human_readable_count(5e15) # (more than trillion)
'5,000 T'
Args:
number: a positive integer number
Return:
A string formatted according to the pattern described above.
"""
assert number >= 0
labels = [" ", "K", "M", "B", "T"]
num_digits = int(np.floor(np.log10(number)) + 1 if number > 0 else 1)
num_groups = int(np.ceil(num_digits / 3))
num_groups = min(num_groups, len(labels)) # don't abbreviate beyond trillions
shift = -3 * (num_groups - 1)
number = number * (10**shift)
index = num_groups - 1
return f"{number:.2f} {labels[index]}"
def to_bytes(dtype) -> int:
# torch.float16 -> 16
return int(str(dtype)[-2:]) // 8
def model_summary(model: torch.nn.Module) -> str:
message = "Model structure:\n"
message += str(model)
tot_params, num_params = 0, 0
for name, param in model.named_parameters():
print(
"name: {}, dtype: {}, device: {}, trainable: {}, shape: {}, numel: {}".format(
name, param.dtype, param.device, param.requires_grad, param.shape, param.numel()
)
)
tot_params += param.numel()
if param.requires_grad:
num_params += param.numel()
percent_trainable = "{:.1f}".format(num_params * 100.0 / tot_params)
tot_params = get_human_readable_count(tot_params)
num_params = get_human_readable_count(num_params)
message += "\n\nModel summary:\n"
message += f" Class Name: {model.__class__.__name__}\n"
message += f" Total Number of model parameters: {tot_params}\n"
message += f" Number of trainable parameters: {num_params} ({percent_trainable}%)\n"
dtype = next(iter(model.parameters())).dtype
message += f" Type: {dtype}"
return message

View File

@@ -0,0 +1,48 @@
"""Torch utility module."""
import torch
if torch.distributed.is_available():
from torch.distributed import ReduceOp
def recursive_sum(obj, weight: torch.Tensor, distributed: bool = False):
assert weight.dim() == 1, weight.size()
if isinstance(obj, (tuple, list)):
return type(obj)(recursive_sum(v, weight, distributed) for v in obj)
elif isinstance(obj, dict):
return {k: recursive_sum(v, weight, distributed) for k, v in obj.items()}
elif isinstance(obj, torch.Tensor):
assert obj.size() == weight.size(), (obj.size(), weight.size())
obj = (obj * weight.type(obj.dtype)).sum()
if distributed:
torch.distributed.all_reduce(obj, op=ReduceOp.SUM)
return obj
elif obj is None:
return None
else:
raise ValueError(type(obj))
def recursive_divide(a, b: torch.Tensor):
if isinstance(a, (tuple, list)):
return type(a)(recursive_divide(v, b) for v in a)
elif isinstance(a, dict):
return {k: recursive_divide(v, b) for k, v in a.items()}
elif isinstance(a, torch.Tensor):
assert a.size() == b.size(), (a.size(), b.size())
return a / b.type(a.dtype)
elif a is None:
return None
else:
raise ValueError(type(a))
def recursive_average(obj, weight: torch.Tensor, distributed: bool = False):
obj = recursive_sum(obj, weight, distributed)
weight = weight.sum()
if distributed:
torch.distributed.all_reduce(weight, op=ReduceOp.SUM)
# Normalize weight to be sum-to-1
obj = recursive_divide(obj, weight)
return obj, weight

View File

@@ -0,0 +1,10 @@
import random
import numpy as np
import torch
def set_all_random_seed(seed: int):
random.seed(seed)
np.random.seed(seed)
torch.random.manual_seed(seed)

View File

@@ -0,0 +1,732 @@
import math
import os
import time
import torch
import logging
from tqdm import tqdm
from datetime import datetime
import torch.distributed as dist
from torch.cuda.amp import autocast, GradScaler
from contextlib import nullcontext, contextmanager
from pathlib import Path
from funasr.train_utils.device_funcs import to_device
from funasr.train_utils.recursive_op import recursive_average
from funasr.train_utils.average_nbest_models import average_checkpoints
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
try:
import wandb
except:
wandb = None
@contextmanager
def maybe_autocast(enabled):
if enabled:
with autocast():
yield
else:
yield
class Trainer:
"""
A simple trainer class for training a PyTorch model, saving checkpoints at the end of each epoch,
and optionally resuming from a saved checkpoint.
Attributes:
max_epoch (int): Maximum number of epochs for training.
model (torch.nn.Module): The model to be trained.
optim (torch.optim.Optimizer): The optimizer to use for training.
scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
dataloader_train (torch.utils.data.DataLoader): DataLoader for the training dataset.
dataloader_val (torch.utils.data.DataLoader): DataLoader for the validation dataset.
output_dir (str): Directory where model checkpoints will be saved.
resume (str, optional): Path to a checkpoint to resume training from.
"""
def __init__(
self,
local_rank,
use_ddp: bool = False,
use_fsdp: bool = False,
use_fp16: bool = False,
output_dir: str = "./",
**kwargs,
):
"""
Initializes the Trainer class with the model, optimizer, scheduler, dataloaders, and other settings.
Args:
model (torch.nn.Module): The model to be trained.
optim (torch.optim.Optimizer): The optimizer to use for training.
scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
dataloader_train (torch.utils.data.DataLoader): The DataLoader for the training dataset.
dataloader_val (torch.utils.data.DataLoader): The DataLoader for the validation dataset.
**kwargs: Additional keyword arguments:
max_epoch (int): The maximum number of epochs for training.
output_dir (str): The directory where model checkpoints will be saved. Default is './'.
resume (str, optional): The file path to a checkpoint to resume training from.
"""
self.output_dir = output_dir
if not os.path.exists(self.output_dir):
os.makedirs(self.output_dir, exist_ok=True)
self.resume = kwargs.get("resume", True)
self.start_epoch = 0
self.max_epoch = kwargs.get("max_epoch", 100)
self.local_rank = local_rank
self.use_ddp = use_ddp
self.use_fsdp = use_fsdp
self.device = kwargs.get("device", "cuda")
# self.kwargs = kwargs
self.log_interval = kwargs.get("log_interval", 50)
self.batch_total = 0
self.use_fp16 = use_fp16
self.save_checkpoint_interval = kwargs.get("save_checkpoint_interval", 5000)
self.validate_interval = kwargs.get("validate_interval", -1)
if self.validate_interval < 0:
self.validate_interval = self.save_checkpoint_interval
assert (
self.save_checkpoint_interval == self.validate_interval
), f"save_checkpoint_interval must equal to validate_interval"
self.keep_nbest_models = kwargs.get("keep_nbest_models", 500)
self.avg_keep_nbest_models_type = kwargs.get("avg_keep_nbest_models_type", "acc")
self.avg_nbest_model = kwargs.get("avg_nbest_model", 10)
self.accum_grad = kwargs.get("accum_grad", 1)
self.grad_clip = kwargs.get("grad_clip", 10.0)
self.grad_clip_type = kwargs.get("grad_clip_type", 2.0)
try:
rank = dist.get_rank()
world_size = dist.get_world_size()
except:
rank = 0
world_size = 1
logging.warning("distributed is not initialized, only single shard")
self.rank = rank
self.world_size = world_size
self.train_acc_avg = 0.0
self.train_loss_avg = 0.0
self.val_acc_avg = 0.0
self.val_loss_avg = 0.0
self.best_acc_idx = 0
self.saved_ckpts = {}
self.step_or_epoch = -1
self.best_step_or_epoch = ""
self.val_acc_step_or_epoch = {}
self.val_loss_step_or_epoch = {}
self.reset_gpu_cache = kwargs.get("reset_gpu_cache", False)
self.start_data_split_i = 0
self.start_step = 0
self.step_in_epoch = 0
self.use_wandb = kwargs.get("use_wandb", False)
if self.use_wandb:
wandb.login(key=kwargs.get("wandb_token"))
wandb.init(
config=kwargs,
project=kwargs.get("wandb_project", "my_project"),
entity=kwargs.get("wandb_team", "my_team"),
name=kwargs.get("wandb_exp_name", "my_exp"),
dir=output_dir,
job_type="training",
reinit=True,
)
def save_checkpoint(
self,
epoch,
step=None,
model=None,
optim=None,
scheduler=None,
scaler=None,
step_in_epoch=None,
**kwargs,
):
"""
Saves a checkpoint containing the model's state, the optimizer's state,
and the scheduler's state at the end of the given epoch. This method is
intended to be called at the end of each epoch to save the training progress.
Args:
epoch (int): The epoch number at which the checkpoint is being saved.
"""
step_in_epoch = None if step is None else step_in_epoch
if self.rank == 0:
logging.info(f"Save checkpoint: {epoch}, rank: {self.local_rank}\n")
# self.step_or_epoch += 1
state = {
"epoch": epoch,
'step': step,
'total_step': self.batch_total,
"state_dict": model.state_dict(),
"optimizer": optim.state_dict(),
"scheduler": scheduler.state_dict(),
"saved_ckpts": self.saved_ckpts,
"val_acc_step_or_epoch": self.val_acc_step_or_epoch,
"val_loss_step_or_epoch": self.val_loss_step_or_epoch,
"best_step_or_epoch": self.best_step_or_epoch,
"avg_keep_nbest_models_type": self.avg_keep_nbest_models_type,
"step": step,
"step_in_epoch": step_in_epoch,
"data_split_i": kwargs.get("data_split_i", 0),
"data_split_num": kwargs.get("data_split_num", 1),
"batch_total": self.batch_total,
"train_loss_avg": kwargs.get("train_loss_avg", 0),
"train_acc_avg": kwargs.get("train_acc_avg", 0),
}
step = step_in_epoch
if hasattr(model, "module"):
state["state_dict"] = model.module.state_dict()
if scaler:
state["scaler_state"] = scaler.state_dict()
# Create output directory if it does not exist
os.makedirs(self.output_dir, exist_ok=True)
if step is None:
ckpt_name = f"model.pt.ep{epoch}"
else:
ckpt_name = f"model.pt.ep{epoch}.{step}"
filename = os.path.join(self.output_dir, ckpt_name)
torch.save(state, filename)
logging.info(f'Checkpoint saved to {filename}')
latest = Path(os.path.join(self.output_dir, f'model.pt'))
torch.save(state, latest)
if self.best_step_or_epoch == "":
self.best_step_or_epoch = ckpt_name
if self.avg_keep_nbest_models_type == "acc":
if (
self.val_acc_step_or_epoch[ckpt_name]
>= self.val_acc_step_or_epoch[self.best_step_or_epoch]
):
self.best_step_or_epoch = ckpt_name
best_ckpt = Path(os.path.join(self.output_dir, f"model.pt.best"))
torch.save(state, best_ckpt)
logging.info(
f"Update best acc: {self.val_acc_step_or_epoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
)
else:
logging.info(
f"No improvement in acc: {self.val_acc_step_or_epoch[ckpt_name]:.4f} < {self.val_acc_step_or_epoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
)
elif self.avg_keep_nbest_models_type == "loss":
if (
self.val_loss_step_or_epoch[ckpt_name]
<= self.val_loss_step_or_epoch[self.best_step_or_epoch]
):
self.best_step_or_epoch = ckpt_name
best_ckpt = Path(os.path.join(self.output_dir, f"model.pt.best"))
torch.save(state, best_ckpt)
logging.info(
f"Update best loss: {self.val_loss_step_or_epoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
)
else:
logging.info(
f"No improvement in loss: {self.val_loss_step_or_epoch[ckpt_name]:.4f} > {self.val_loss_step_or_epoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
)
else:
print("Undo")
self.saved_ckpts[ckpt_name] = getattr(
self, f"val_{self.avg_keep_nbest_models_type}_step_or_epoch"
)[ckpt_name]
if self.keep_nbest_models > 0:
if len(self.saved_ckpts) > self.keep_nbest_models:
if self.avg_keep_nbest_models_type == "acc":
key = min(self.saved_ckpts, key=self.saved_ckpts.get)
else:
key = max(self.saved_ckpts, key=self.saved_ckpts.get)
if key in self.saved_ckpts:
del self.saved_ckpts[key]
filename = os.path.join(self.output_dir, key)
logging.info(f"Delete: {filename}")
if os.path.exists(filename):
os.remove(filename)
if self.use_ddp or self.use_fsdp:
dist.barrier()
def resume_checkpoint(
self,
model=None,
optim=None,
scheduler=None,
scaler=None,
):
"""
Resumes training from a checkpoint at the given file path.
Loads the model's state, the optimizer's state, and the scheduler's state.
Args:
resume_path (str): The file path to the checkpoint to resume from.
"""
if self.resume:
ckpt = os.path.join(self.output_dir, "model.pt")
if os.path.isfile(ckpt):
checkpoint = torch.load(ckpt, map_location="cpu")
self.start_epoch = checkpoint["epoch"]
# self.model.load_state_dict(checkpoint['state_dict'])
src_state = checkpoint["state_dict"]
dst_state = model.state_dict()
for k in dst_state.keys():
if not k.startswith("module.") and "module." + k in src_state.keys():
k_ddp = "module." + k
elif k.startswith("module.") and "module." + k not in src_state.keys():
k_ddp = k.replace("module.", "", 1)
else:
k_ddp = k
if k_ddp in src_state.keys():
dst_state[k] = src_state[k_ddp]
else:
print(f"Miss key in ckpt: model: {k}, ckpt: {k_ddp}")
model.load_state_dict(dst_state)
optim.load_state_dict(checkpoint["optimizer"])
scheduler.load_state_dict(checkpoint["scheduler"])
if scaler is not None and "scaler_state" in checkpoint:
scaler.load_state_dict(checkpoint["scaler_state"])
self.saved_ckpts = checkpoint["saved_ckpts"]
self.val_acc_step_or_epoch = (
checkpoint["val_acc_step_or_epoch"]
if "val_acc_step_or_epoch" in checkpoint
else {}
)
self.val_loss_step_or_epoch = (
checkpoint["val_loss_step_or_epoch"]
if "val_loss_step_or_epoch" in checkpoint
else {}
)
self.best_step_or_epoch = (
checkpoint["best_step_or_epoch"] if "best_step_or_epoch" in checkpoint else ""
)
self.start_data_split_i = (
checkpoint["data_split_i"] if "data_split_i" in checkpoint else 0
)
self.batch_total = checkpoint["batch_total"] if "batch_total" in checkpoint else 0
self.start_step = checkpoint["step"] if "step" in checkpoint else 0
self.start_step = 0 if self.start_step is None else self.start_step
self.step_in_epoch = (
checkpoint["step_in_epoch"] if "step_in_epoch" in checkpoint else 0
)
self.step_in_epoch = 0 if self.step_in_epoch is None else self.step_in_epoch
print(checkpoint["train_acc_avg"])
self.train_acc_avg = (
checkpoint["train_acc_avg"] if "train_acc_avg" in checkpoint else 0
)
self.train_loss_avg = (
checkpoint["train_loss_avg"] if "train_loss_avg" in checkpoint else 0
)
model.to(self.device)
print(f"Checkpoint loaded successfully from '{ckpt}'")
else:
print(f"No checkpoint found at '{ckpt}', does not resume status!")
if self.use_ddp or self.use_fsdp:
dist.barrier()
def train_epoch(
self,
model=None,
optim=None,
scheduler=None,
scaler=None,
dataloader_train=None,
dataloader_val=None,
epoch=None,
writer=None,
**kwargs,
):
"""
Defines the training process for a single epoch with gradient accumulation.
Args:
epoch (int): The current epoch number.
"""
if self.use_ddp or self.use_fsdp:
dist.barrier()
logging.info(f"Train epoch: {epoch}, rank: {self.rank}\n")
model.train()
# Set the number of steps for gradient accumulation
accum_grad = self.accum_grad
# Initialize the gradient accumulation
optim.zero_grad()
speed_stats = {}
iterator_stop = torch.tensor(0).to(self.device)
dataloader_train.batch_sampler.set_epoch(epoch)
time_beg = time.perf_counter()
time5 = time_beg
for batch_idx, batch in enumerate(dataloader_train):
# if self.use_ddp or self.use_fsdp:
# dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
# if iterator_stop > 0:
# break
self.batch_total += 1
self.step_in_epoch += 1
time1 = time.perf_counter()
speed_stats["data_load"] = f"{time1-time_beg:0.3f}"
batch = to_device(batch, self.device)
my_context = nullcontext
if self.use_ddp or self.use_fsdp:
my_context = model.no_sync if batch_idx % accum_grad != 0 else my_context
with my_context():
time2 = time.perf_counter()
with maybe_autocast(self.use_fp16):
retval = model(**batch)
# if (
# self.reset_gpu_cache
# and (torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024) > 70
# ):
# torch.cuda.empty_cache()
loss, stats, weight = retval
stats = {k: v for k, v in stats.items() if v is not None}
if self.use_ddp or self.use_fsdp:
# Apply weighted averaging for loss and stats
loss = (loss * weight.type(loss.dtype)).sum()
# if distributed, this method can also apply all_reduce()
# stats, weight = recursive_average(stats, weight, distributed=True)
if self.use_ddp or self.use_fsdp:
dist.all_reduce(weight, op=dist.ReduceOp.SUM)
# Now weight is summation over all workers
loss /= weight.sum() # shape:[1] -> shape:[]
# Multiply world_size because DistributedDataParallel
# automatically normalizes the gradient by world_size.
loss *= self.world_size
# loss *= self.world_size
# Scale the loss since we're not updating for every mini-batch
loss = loss / accum_grad
time3 = time.perf_counter()
speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
if self.use_fp16:
scaler.scale(loss).backward()
else:
loss.backward()
time4 = time.perf_counter()
speed_stats["backward_and_AllReaduce_time"] = f"{time4 - time3:0.3f}"
self.train_loss_avg = (
self.train_loss_avg * (batch_idx + kwargs.get("start_step", 0))
+ loss.detach().cpu().item()
) / (batch_idx + kwargs.get("start_step", 0) + 1)
if "acc" in stats:
self.train_acc_avg = (
self.train_acc_avg * (batch_idx + kwargs.get("start_step", 0))
+ stats["acc"].detach().cpu().item()
) / (batch_idx + kwargs.get("start_step", 0) + 1)
# Perform an optimizer step only after accumulating enough gradients
if (batch_idx + 1) % accum_grad == 0:
# Perform gradient clipping if it is set
if self.grad_clip > 0:
grad_norm = torch.nn.utils.clip_grad_norm_(
model.parameters(),
max_norm=self.grad_clip,
norm_type=self.grad_clip_type,
)
if not torch.isfinite(grad_norm):
logging.warning(
f"The grad norm is {grad_norm}. Skipping updating the model."
)
optim.zero_grad() # Reset gradients
continue
# Execute an optimization step (update model parameters)
if self.use_ddp or self.use_fsdp:
dist.barrier()
if self.use_fp16:
scaler.step(optim)
scaler.update()
else:
optim.step()
scheduler.step()
# Clear gradients for the next accumulation stage
optim.zero_grad(set_to_none=True)
if self.use_ddp or self.use_fsdp:
train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(
self.device
)
train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(
self.device
)
dist.all_reduce(train_loss_avg, op=dist.ReduceOp.SUM)
dist.all_reduce(train_acc_avg, op=dist.ReduceOp.SUM)
self.train_loss_avg = train_loss_avg.detach().cpu().item() / self.world_size
self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
total_time = f"{(time.perf_counter() - time5)/accum_grad:0.3f}"
time5 = time.perf_counter()
speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
speed_stats["total_time"] = total_time
lr = scheduler.get_last_lr()[0]
batch_num_epoch = 1
if hasattr(dataloader_train, "__len__"):
batch_num_epoch = len(dataloader_train)
self.log(
epoch,
batch_idx,
log_step=batch_idx + kwargs.get("start_step", 0),
step_in_epoch=self.step_in_epoch,
batch_num_epoch=batch_num_epoch,
lr=lr,
loss=accum_grad * loss.detach().cpu().item(),
speed_stats=speed_stats,
stats=stats,
writer=writer,
tag="train",
data_split_i=kwargs.get("data_split_i", 0),
data_split_num=kwargs.get("data_split_num", 1),
)
if self.step_in_epoch % self.validate_interval == 0:
self.validate_epoch(
model=model,
dataloader_val=dataloader_val,
epoch=epoch,
writer=writer,
step=batch_idx + 1,
step_in_epoch=self.step_in_epoch,
)
if self.step_in_epoch % self.save_checkpoint_interval == 0:
self.save_checkpoint(
epoch,
model=model,
optim=optim,
scheduler=scheduler,
scaler=scaler,
step=batch_idx + 1,
step_in_epoch=self.step_in_epoch,
data_split_i=kwargs.get("data_split_i", 0),
data_split_num=kwargs.get("data_split_num", 1),
train_loss_avg=self.train_loss_avg,
train_acc_avg=self.train_acc_avg,
)
time_beg = time.perf_counter()
# else:
# if self.use_ddp or self.use_fsdp:
# iterator_stop.fill_(1)
# dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
if self.use_ddp or self.use_fsdp:
dist.barrier()
# iterator_stop = torch.tensor(0).to(self.device)
def validate_epoch(
self,
model=None,
dataloader_val=None,
epoch=None,
writer=None,
**kwargs,
):
"""
Defines the validation process for a single epoch.
Should be implemented with the actual model validation steps.
Args:
epoch (int): The current epoch number.
"""
if self.use_ddp or self.use_fsdp:
dist.barrier()
logging.info(f"Validate epoch: {epoch}, rank: {self.rank}\n")
model.eval()
with torch.no_grad():
speed_stats = {}
time5 = time.perf_counter()
iterator_stop = torch.tensor(0).to(self.device)
dataloader_val.batch_sampler.set_epoch(epoch)
for batch_idx, batch in enumerate(dataloader_val):
if self.use_ddp or self.use_fsdp:
dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
if iterator_stop > 0:
break
time1 = time.perf_counter()
speed_stats["data_load"] = f"{time1 - time5:0.3f}"
batch = to_device(batch, self.device)
time2 = time.perf_counter()
retval = model(**batch)
time3 = time.perf_counter()
speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
loss, stats, weight = retval
stats = {k: v for k, v in stats.items() if v is not None}
if self.use_ddp or self.use_fsdp:
# Apply weighted averaging for loss and stats
loss = (loss * weight.type(loss.dtype)).sum()
# if distributed, this method can also apply all_reduce()
# stats, weight = recursive_average(stats, weight, distributed=True)
if self.use_ddp or self.use_fsdp:
dist.all_reduce(weight, op=dist.ReduceOp.SUM)
# Now weight is summation over all workers
loss /= weight.sum() # shape:[1] -> shape:[]
# Multiply world_size because DistributedDataParallel
# automatically normalizes the gradient by world_size.
loss *= self.world_size
# Scale the loss since we're not updating for every mini-batch
loss = loss
time4 = time.perf_counter()
if torch.isfinite(loss):
self.val_loss_avg = (self.val_loss_avg * batch_idx + loss.detach().cpu().item()) / (
batch_idx + 1
)
if "acc" in stats:
self.val_acc_avg = (
self.val_acc_avg * batch_idx + stats["acc"].detach().cpu().item()
) / (batch_idx + 1)
if self.use_ddp or self.use_fsdp:
val_loss_avg = torch.tensor(self.val_loss_avg, dtype=torch.float32).to(
self.device
)
val_acc_avg = torch.tensor(self.val_acc_avg, dtype=torch.float32).to(
self.device
)
dist.all_reduce(val_loss_avg, op=dist.ReduceOp.SUM)
dist.all_reduce(val_acc_avg, op=dist.ReduceOp.SUM)
self.val_loss_avg = val_loss_avg.detach().cpu().item() / self.world_size
self.val_acc_avg = val_acc_avg.detach().cpu().item() / self.world_size
time5 = time.perf_counter()
batch_num_epoch = 1
if hasattr(dataloader_val, "__len__"):
batch_num_epoch = len(dataloader_val)
self.log(
epoch,
batch_idx,
batch_num_epoch=batch_num_epoch,
lr=0.0,
loss=loss.detach().cpu().item(),
speed_stats=speed_stats,
stats=stats,
writer=writer,
tag="val",
)
else:
if self.use_ddp or self.use_fsdp:
iterator_stop.fill_(1)
dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
if kwargs.get("step_in_epoch", None) is None:
ckpt_name = f"model.pt.ep{epoch}"
else:
ckpt_name = f'model.pt.ep{epoch}.{kwargs.get("step_in_epoch")}'
self.val_acc_step_or_epoch[ckpt_name] = self.val_acc_avg
self.val_loss_step_or_epoch[ckpt_name] = self.val_loss_avg
model.train()
if self.use_ddp or self.use_fsdp:
dist.barrier()
iterator_stop = torch.tensor(0).to(self.device)
def log(
self,
epoch=0,
batch_idx=0,
step_in_epoch=0,
batch_num_epoch=-1,
lr=0.0,
loss=0.0,
speed_stats=None,
stats=None,
writer=None,
tag="train",
data_split_i=0,
data_split_num=1,
log_step=None,
**kwargs,
):
if (batch_idx + 1) % self.log_interval == 0:
batch_idx = log_step if log_step is not None else batch_idx
gpu_info = (
"GPU, memory: usage: {:.3f} GB, "
"peak: {:.3f} GB, "
"cache: {:.3f} GB, "
"cache_peak: {:.3f} GB".format(
torch.cuda.memory_allocated() / 1024 / 1024 / 1024,
torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024,
torch.cuda.memory_reserved() / 1024 / 1024 / 1024,
torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024,
)
)
loss_avg_epoch = getattr(self, f"{tag}_loss_avg")
acc_avg_epoch = getattr(self, f"{tag}_acc_avg")
description = (
f"{tag}, "
f"rank: {self.rank}, "
f"epoch: {epoch}/{self.max_epoch}, "
f"data_slice: {data_split_i}/{data_split_num}, "
f"step_in_slice: {batch_idx + 1}/{batch_num_epoch}, step_in_epoch: {step_in_epoch}, total step: {self.batch_total}, "
f"(loss_avg_rank: {loss:.3f}), "
f"(loss_avg_slice: {loss_avg_epoch:.3f}), "
f"(ppl_avg_slice: {math.exp(loss_avg_epoch):.3e}), "
f"(acc_avg_slice: {acc_avg_epoch:.3f}), "
f"(lr: {lr:.3e}), "
f"{[(k, round(v.detach().cpu().item(), 3)) for k, v in stats.items()]}, "
f"{speed_stats}, "
f"{gpu_info}"
)
logging.info(description)
description_dict = {
f"rank{self.rank}_loss/{tag}": loss,
f"rank{self.rank}_lr/{tag}": lr,
}
if writer is not None:
writer.add_scalar(f"rank{self.rank}_loss/{tag}", loss, self.batch_total)
writer.add_scalar(f"rank{self.rank}_lr/{tag}", lr, self.batch_total)
for key, var in stats.items():
writer.add_scalar(
f"stats_rank{self.rank}_{key}/{tag}", var.item(), self.batch_total
)
description_dict[f"stats_rank{self.rank}_{key}/{tag}"] = var.item()
for key, var in speed_stats.items():
writer.add_scalar(
f"stats_rank{self.rank}_{key}/{tag}", eval(var), self.batch_total
)
description_dict[f"stats_rank{self.rank}_{key}/{tag}"] = eval(var)
if self.use_wandb and wandb is not None:
wandb.log(
description_dict,
setp=self.batch_total,
)
def close(self, writer=None):
if self.use_ddp or self.use_fsdp:
dist.barrier()
if writer is not None:
writer.close()
if self.use_ddp or self.use_fsdp:
torch.distributed.destroy_process_group()

View File

@@ -0,0 +1,997 @@
import math
import os
import time
import torch
import logging
from tqdm import tqdm
from datetime import datetime
import torch.distributed as dist
from torch.cuda.amp import autocast, GradScaler
from contextlib import nullcontext, contextmanager
from pathlib import Path
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from funasr.train_utils.device_funcs import to_device
from funasr.train_utils.recursive_op import recursive_average
from funasr.train_utils.average_nbest_models import average_checkpoints
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
import funasr.utils.misc as misc_utils
try:
import wandb
except:
wandb = None
@contextmanager
def maybe_autocast(dtype=None, use_deepspeed=False):
if use_deepspeed:
with torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False):
yield
else:
if dtype == torch.float16 or dtype == torch.bfloat16:
yield
# with autocast(enabled=True, dtype=dtype):
# yield
else:
yield
class Trainer:
"""
A simple trainer class for training a PyTorch model, saving checkpoints at the end of each epoch,
and optionally resuming from a saved checkpoint.
Attributes:
max_epoch (int): Maximum number of epochs for training.
model (torch.nn.Module): The model to be trained.
optim (torch.optim.Optimizer): The optimizer to use for training.
scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
dataloader_train (torch.utils.data.DataLoader): DataLoader for the training dataset.
dataloader_val (torch.utils.data.DataLoader): DataLoader for the validation dataset.
output_dir (str): Directory where model checkpoints will be saved.
resume (str, optional): Path to a checkpoint to resume training from.
"""
def __init__(
self,
rank=0,
local_rank=0,
world_size=1,
use_ddp: bool = False,
use_fsdp: bool = False,
use_fp16: bool = False,
use_bf16: bool = False,
use_deepspeed: bool = False,
output_dir: str = "./",
**kwargs,
):
"""
Initializes the Trainer class with the model, optimizer, scheduler, dataloaders, and other settings.
Args:
model (torch.nn.Module): The model to be trained.
optim (torch.optim.Optimizer): The optimizer to use for training.
scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
dataloader_train (torch.utils.data.DataLoader): The DataLoader for the training dataset.
dataloader_val (torch.utils.data.DataLoader): The DataLoader for the validation dataset.
**kwargs: Additional keyword arguments:
max_epoch (int): The maximum number of epochs for training.
output_dir (str): The directory where model checkpoints will be saved. Default is './'.
resume (str, optional): The file path to a checkpoint to resume training from.
"""
self.rank = rank
self.local_rank = local_rank
self.world_size = world_size
self.use_ddp = use_ddp
self.use_fsdp = use_fsdp
self.device = kwargs.get("device", "cuda")
self.output_dir = output_dir
if not os.path.exists(self.output_dir):
os.makedirs(self.output_dir, exist_ok=True)
self.resume = kwargs.get("resume", True)
self.start_epoch = 0
self.max_epoch = kwargs.get("max_epoch", 100)
# self.kwargs = kwargs
self.log_interval = kwargs.get("log_interval", 50)
self.batch_total = 0
self.dtype = torch.float32
self.use_fp16 = use_fp16
self.use_bf16 = use_bf16
if self.use_fp16:
self.dtype = torch.float16
if self.use_bf16:
self.dtype = torch.bfloat16
self.save_checkpoint_interval = kwargs.get("save_checkpoint_interval", 5000)
self.validate_interval = kwargs.get("validate_interval", 5000)
self.keep_nbest_models = kwargs.get("keep_nbest_models", 500)
self.avg_keep_nbest_models_type = kwargs.get("avg_keep_nbest_models_type", "acc")
self.avg_nbest_model = kwargs.get("avg_nbest_model", 10)
self.accum_grad = kwargs.get("accum_grad", 1)
self.grad_clip = kwargs.get("grad_clip", 10.0)
self.grad_clip_type = kwargs.get("grad_clip_type", 2.0)
self.train_acc_avg = 0.0
self.train_loss_avg = 0.0
self.val_acc_avg = 0.0
self.val_loss_avg = 0.0
self.best_acc_idx = 0
self.saved_ckpts = {}
self.step_or_epoch = -1
self.best_step_or_epoch = ""
self.val_acc_step_or_eoch = {}
self.val_loss_step_or_eoch = {}
self.reset_gpu_cache = kwargs.get("reset_gpu_cache", False)
self.start_data_split_i = 0
self.start_step = 0
self.step_in_epoch = 0
self.use_wandb = kwargs.get("use_wandb", False)
if self.use_wandb:
wandb.login(key=kwargs.get("wandb_token"))
wandb.init(
config=kwargs,
project=kwargs.get("wandb_project", "my_project"),
entity=kwargs.get("wandb_team", "my_team"),
name=kwargs.get("wandb_exp_name", "my_exp"),
dir=output_dir,
job_type="training",
reinit=True,
)
tensorboard_dir = os.path.join(output_dir, "tensorboard")
os.makedirs(tensorboard_dir, exist_ok=True)
try:
from tensorboardX import SummaryWriter
self.writer = SummaryWriter(tensorboard_dir) # if trainer.rank == 0 else None
except:
self.writer = None
self.use_deepspeed = use_deepspeed
self.deepspeed_config = kwargs.get("deepspeed_config", "")
excludes = kwargs.get("excludes", None)
if excludes is not None:
if isinstance(excludes, str):
excludes = excludes.split(",")
self.excludes = excludes
effective_save_name_excludes = kwargs.get("effective_save_name_excludes", None)
if effective_save_name_excludes is not None:
if isinstance(effective_save_name_excludes, str):
effective_save_name_excludes = effective_save_name_excludes.split(",")
self.effective_save_name_excludes = effective_save_name_excludes
def save_checkpoint(
self,
epoch,
step=None,
model=None,
optim=None,
scheduler=None,
scaler=None,
step_in_epoch=None,
**kwargs,
):
"""
Saves a checkpoint containing the model's state, the optimizer's state,
and the scheduler's state at the end of the given epoch. This method is
intended to be called at the end of each epoch to save the training progress.
Args:
epoch (int): The epoch number at which the checkpoint is being saved.
"""
if self.use_ddp or self.use_fsdp:
dist.barrier()
step_in_epoch = None if step is None else step_in_epoch
if self.use_deepspeed:
logging.info(f"Save checkpoint: {epoch}, rank: {self.local_rank}\n")
# self.step_or_epoch += 1
state = {
"epoch": epoch,
# "state_dict": model.state_dict(),
# "optimizer": optim.state_dict(),
# "scheduler": scheduler.state_dict(),
"saved_ckpts": self.saved_ckpts,
"val_acc_step_or_eoch": self.val_acc_step_or_eoch,
"val_loss_step_or_eoch": self.val_loss_step_or_eoch,
"best_step_or_epoch": self.best_step_or_epoch,
"avg_keep_nbest_models_type": self.avg_keep_nbest_models_type,
"step": step,
"step_in_epoch": step_in_epoch,
"data_split_i": kwargs.get("data_split_i", 0),
"data_split_num": kwargs.get("data_split_num", 1),
"batch_total": self.batch_total,
"train_loss_avg": kwargs.get("train_loss_avg", 0),
"train_acc_avg": kwargs.get("train_acc_avg", 0),
}
step = step_in_epoch
if hasattr(model, "module"):
state["state_dict"] = model.module.state_dict()
if scaler:
state["scaler_state"] = scaler.state_dict()
# Create output directory if it does not exist
os.makedirs(self.output_dir, exist_ok=True)
if step is None:
ckpt_name = f"model.pt.ep{epoch}"
else:
ckpt_name = f"model.pt.ep{epoch}.{step}"
filename = os.path.join(self.output_dir, ckpt_name)
# torch.save(state, filename)
with torch.no_grad():
model.save_checkpoint(save_dir=self.output_dir, tag=ckpt_name, client_state=state)
logging.info(f"\nCheckpoint saved to {filename}\n")
latest = Path(os.path.join(self.output_dir, f"model.pt"))
# torch.save(state, latest)
with torch.no_grad():
model.save_checkpoint(save_dir=self.output_dir, tag=f"model.pt", client_state=state)
if self.best_step_or_epoch == "":
self.best_step_or_epoch = ckpt_name
if self.avg_keep_nbest_models_type == "acc":
if (
self.val_acc_step_or_eoch[ckpt_name]
>= self.val_acc_step_or_eoch[self.best_step_or_epoch]
):
self.best_step_or_epoch = ckpt_name
best_ckpt = Path(os.path.join(self.output_dir, f"model.pt.best"))
# torch.save(state, best_ckpt)
with torch.no_grad():
model.save_checkpoint(
save_dir=self.output_dir, tag=f"model.pt.best", client_state=state
)
logging.info(
f"Update best acc: {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
)
else:
logging.info(
f"No improvement in acc: {self.val_acc_step_or_eoch[ckpt_name]:.4f} < {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
)
elif self.avg_keep_nbest_models_type == "loss":
if (
self.val_loss_step_or_eoch[ckpt_name]
<= self.val_loss_step_or_eoch[self.best_step_or_epoch]
):
self.best_step_or_epoch = ckpt_name
best_ckpt = Path(os.path.join(self.output_dir, f"model.pt.best"))
# torch.save(state, best_ckpt)
with torch.no_grad():
model.save_checkpoint(
save_dir=self.output_dir, tag=f"model.pt.best", client_state=state
)
logging.info(
f"Update best loss: {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
)
else:
logging.info(
f"No improvement in loss: {self.val_loss_step_or_eoch[ckpt_name]:.4f} > {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
)
else:
print("Undo")
self.saved_ckpts[ckpt_name] = getattr(
self, f"val_{self.avg_keep_nbest_models_type}_step_or_eoch"
)[ckpt_name]
if self.keep_nbest_models > 0:
if len(self.saved_ckpts) > self.keep_nbest_models:
if self.avg_keep_nbest_models_type == "acc":
key = min(self.saved_ckpts, key=self.saved_ckpts.get)
else:
key = max(self.saved_ckpts, key=self.saved_ckpts.get)
if key in self.saved_ckpts:
del self.saved_ckpts[key]
filename = os.path.join(self.output_dir, key)
logging.info(f"Delete: {filename}")
if os.path.exists(filename):
# os.remove(filename)
misc_utils.smart_remove(filename)
elif self.use_fsdp:
pass
elif self.rank == 0:
logging.info(
f"Save checkpoint: {epoch}, rank: {self.rank}, local_rank: {self.local_rank}\n"
)
# self.step_or_epoch += 1
state = {
"epoch": epoch,
"optimizer": optim.state_dict(),
"scheduler": scheduler.state_dict(),
"saved_ckpts": self.saved_ckpts,
"val_acc_step_or_eoch": self.val_acc_step_or_eoch,
"val_loss_step_or_eoch": self.val_loss_step_or_eoch,
"best_step_or_epoch": self.best_step_or_epoch,
"avg_keep_nbest_models_type": self.avg_keep_nbest_models_type,
"step": step,
"step_in_epoch": step_in_epoch,
"data_split_i": kwargs.get("data_split_i", 0),
"data_split_num": kwargs.get("data_split_num", 1),
"batch_total": self.batch_total,
"train_loss_avg": kwargs.get("train_loss_avg", 0),
"train_acc_avg": kwargs.get("train_acc_avg", 0),
}
step = step_in_epoch
if hasattr(model, "module"):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
if self.effective_save_name_excludes is not None:
logging.info(f"effective_save_name_excludes: {self.effective_save_name_excludes}")
dst_state_dict = {}
for k in state_dict.keys():
for k_ex in self.effective_save_name_excludes:
k_tmp = k.replace("module.", "")
if k.startswith(k_ex):
logging.info(f"key: {k} matching: {k_ex}, not save it")
break
else:
dst_state_dict[k] = state_dict[k]
state["state_dict"] = dst_state_dict
else:
state["state_dict"] = state_dict
if scaler:
state["scaler_state"] = scaler.state_dict()
# Create output directory if it does not exist
os.makedirs(self.output_dir, exist_ok=True)
if step is None:
ckpt_name = f"model.pt.ep{epoch}"
else:
ckpt_name = f"model.pt.ep{epoch}.{step}"
filename = os.path.join(self.output_dir, ckpt_name)
torch.save(state, filename)
logging.info(f"\nCheckpoint saved to {filename}\n")
latest = Path(os.path.join(self.output_dir, f"model.pt"))
torch.save(state, latest)
if self.best_step_or_epoch == "":
self.best_step_or_epoch = ckpt_name
if self.avg_keep_nbest_models_type == "acc":
if (
self.val_acc_step_or_eoch[ckpt_name]
>= self.val_acc_step_or_eoch[self.best_step_or_epoch]
):
self.best_step_or_epoch = ckpt_name
best_ckpt = Path(os.path.join(self.output_dir, f"model.pt.best"))
torch.save(state, best_ckpt)
logging.info(
f"Update best acc: {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
)
else:
logging.info(
f"No improvement in acc: {self.val_acc_step_or_eoch[ckpt_name]:.4f} < {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
)
elif self.avg_keep_nbest_models_type == "loss":
if (
self.val_loss_step_or_eoch[ckpt_name]
<= self.val_loss_step_or_eoch[self.best_step_or_epoch]
):
self.best_step_or_epoch = ckpt_name
best_ckpt = Path(os.path.join(self.output_dir, f"model.pt.best"))
torch.save(state, best_ckpt)
logging.info(
f"Update best loss: {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
)
else:
logging.info(
f"No improvement in loss: {self.val_loss_step_or_eoch[ckpt_name]:.4f} > {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
)
else:
print("Undo")
self.saved_ckpts[ckpt_name] = getattr(
self, f"val_{self.avg_keep_nbest_models_type}_step_or_eoch"
)[ckpt_name]
if self.keep_nbest_models > 0:
if len(self.saved_ckpts) > self.keep_nbest_models:
if self.avg_keep_nbest_models_type == "acc":
key = min(self.saved_ckpts, key=self.saved_ckpts.get)
else:
key = max(self.saved_ckpts, key=self.saved_ckpts.get)
if key in self.saved_ckpts:
del self.saved_ckpts[key]
filename = os.path.join(self.output_dir, key)
logging.info(f"Delete: {filename}")
if os.path.exists(filename):
# os.remove(filename)
misc_utils.smart_remove(filename)
if self.use_ddp or self.use_fsdp:
dist.barrier()
def resume_checkpoint(
self,
model=None,
optim=None,
scheduler=None,
scaler=None,
):
"""
Resumes training from a checkpoint at the given file path.
Loads the model's state, the optimizer's state, and the scheduler's state.
Args:
resume_path (str): The file path to the checkpoint to resume from.
"""
if self.resume:
if self.use_deepspeed:
ckpt = os.path.join(self.output_dir, "model.pt")
if os.path.exists(ckpt):
_, checkpoint = model.load_checkpoint(self.output_dir, "model.pt")
self.start_epoch = checkpoint["epoch"]
self.saved_ckpts = checkpoint["saved_ckpts"]
self.val_acc_step_or_eoch = (
checkpoint["val_acc_step_or_eoch"]
if "val_acc_step_or_eoch" in checkpoint
else {}
)
self.val_loss_step_or_eoch = (
checkpoint["val_loss_step_or_eoch"]
if "val_loss_step_or_eoch" in checkpoint
else {}
)
self.best_step_or_epoch = (
checkpoint["best_step_or_epoch"]
if "best_step_or_epoch" in checkpoint
else ""
)
self.start_data_split_i = (
checkpoint["data_split_i"] if "data_split_i" in checkpoint else 0
)
self.batch_total = (
checkpoint["batch_total"] if "batch_total" in checkpoint else 0
)
self.start_step = checkpoint["step"] if "step" in checkpoint else 0
self.start_step = 0 if self.start_step is None else self.start_step
self.step_in_epoch = (
checkpoint["step_in_epoch"] if "step_in_epoch" in checkpoint else 0
)
self.step_in_epoch = 0 if self.step_in_epoch is None else self.step_in_epoch
print(checkpoint["train_acc_avg"])
self.train_acc_avg = (
checkpoint["train_acc_avg"] if "train_acc_avg" in checkpoint else 0
)
self.train_loss_avg = (
checkpoint["train_loss_avg"] if "train_loss_avg" in checkpoint else 0
)
model.to(self.device)
print(f"Checkpoint loaded successfully from '{ckpt}'")
else:
print(f"No checkpoint found at '{ckpt}', does not resume status!")
else:
ckpt = os.path.join(self.output_dir, "model.pt")
if os.path.isfile(ckpt):
checkpoint = torch.load(ckpt, map_location="cpu")
self.start_epoch = checkpoint["epoch"]
# self.model.load_state_dict(checkpoint['state_dict'])
src_state = checkpoint["state_dict"]
dst_state = model.state_dict()
for k in dst_state.keys():
excludes_flag = False
if self.excludes is not None:
for k_ex in self.excludes:
k_tmp = k.replace("module.", "")
if k_tmp.startswith(k_ex):
logging.info(f"key: {k} matching: {k_ex}, excluded")
excludes_flag = True
break
if excludes_flag:
continue
if not k.startswith("module.") and "module." + k in src_state.keys():
k_ddp = "module." + k
elif k.startswith("module.") and "module." + k not in src_state.keys():
k_ddp = k.replace("module.", "", 1)
else:
k_ddp = k
if k_ddp in src_state.keys():
dst_state[k] = src_state[k_ddp]
else:
print(f"Miss key in ckpt: model: {k}, ckpt: {k_ddp}")
model.load_state_dict(dst_state)
optim.load_state_dict(checkpoint["optimizer"])
scheduler.load_state_dict(checkpoint["scheduler"])
if scaler is not None and "scaler_state" in checkpoint:
scaler.load_state_dict(checkpoint["scaler_state"])
self.saved_ckpts = checkpoint["saved_ckpts"]
self.val_acc_step_or_eoch = (
checkpoint["val_acc_step_or_eoch"]
if "val_acc_step_or_eoch" in checkpoint
else {}
)
self.val_loss_step_or_eoch = (
checkpoint["val_loss_step_or_eoch"]
if "val_loss_step_or_eoch" in checkpoint
else {}
)
self.best_step_or_epoch = (
checkpoint["best_step_or_epoch"]
if "best_step_or_epoch" in checkpoint
else ""
)
self.start_data_split_i = (
checkpoint["data_split_i"] if "data_split_i" in checkpoint else 0
)
self.batch_total = (
checkpoint["batch_total"] if "batch_total" in checkpoint else 0
)
self.start_step = checkpoint["step"] if "step" in checkpoint else 0
self.start_step = 0 if self.start_step is None else self.start_step
self.step_in_epoch = (
checkpoint["step_in_epoch"] if "step_in_epoch" in checkpoint else 0
)
self.step_in_epoch = 0 if self.step_in_epoch is None else self.step_in_epoch
print(checkpoint["train_acc_avg"])
self.train_acc_avg = (
checkpoint["train_acc_avg"] if "train_acc_avg" in checkpoint else 0
)
self.train_loss_avg = (
checkpoint["train_loss_avg"] if "train_loss_avg" in checkpoint else 0
)
model.to(self.device)
print(f"Checkpoint loaded successfully from '{ckpt}'")
else:
print(f"No checkpoint found at '{ckpt}', does not resume status!")
if self.use_ddp or self.use_fsdp:
dist.barrier()
def train_epoch(
self,
model=None,
optim=None,
scheduler=None,
scaler=None,
dataloader_train=None,
dataloader_val=None,
epoch=None,
**kwargs,
):
"""
Defines the training process for a single epoch with gradient accumulation.
Args:
epoch (int): The current epoch number.
"""
if self.use_ddp or self.use_fsdp or self.use_deepspeed:
dist.barrier()
logging.info(f"Train epoch: {epoch}, rank: {self.rank}\n")
model.train()
# Set the number of steps for gradient accumulation
accum_grad = self.accum_grad
# Initialize the gradient accumulation
optim.zero_grad()
speed_stats = {}
iterator_stop = torch.tensor(0).to(self.device)
dataloader_train.batch_sampler.set_epoch(epoch)
time_beg = time.perf_counter()
time5 = time_beg
for batch_idx, batch in enumerate(dataloader_train):
self.batch_total += 1
self.step_in_epoch += 1
loss_dict = {
"speed_stats": {},
"epoch": epoch,
"batch_idx": batch_idx,
"data_split_i": kwargs.get("data_split_i", 0),
"data_split_num": kwargs.get("data_split_num", 1),
"log_step": batch_idx + kwargs.get("start_step", 0),
"batch_total": self.batch_total,
"step_in_epoch": self.step_in_epoch,
}
time1 = time.perf_counter()
loss_dict["speed_stats"]["data_load"] = f"{time1-time_beg:0.3f}"
batch = to_device(batch, self.device)
my_context = nullcontext
if self.use_ddp or self.use_fsdp:
my_context = model.no_sync if batch_idx % accum_grad != 0 else my_context
with my_context():
time2 = time.perf_counter()
self.forward_step(model, batch, loss_dict=loss_dict)
time3 = time.perf_counter()
loss_dict["speed_stats"]["forward_time"] = f"{time3 - time2:0.3f}"
self.backward_step(model, scaler, loss_dict=loss_dict)
time4 = time.perf_counter()
loss_dict["speed_stats"]["backward_time"] = f"{time4 - time3:0.3f}"
self.update_step(model, optim, scheduler, scaler, loss_dict=loss_dict)
total_time = f"{(time.perf_counter() - time5):0.3f}"
time5 = time.perf_counter()
loss_dict["speed_stats"]["optim_time"] = f"{time5 - time4:0.3f}"
loss_dict["speed_stats"]["total_time"] = total_time
loss_dict["lr"] = scheduler.get_last_lr()[0]
loss_dict["batch_num_epoch"] = len(dataloader_train)
self.train_loss_avg = (
self.train_loss_avg * batch_idx + loss_dict["loss"].detach().cpu().item()
) / (batch_idx + 1)
if "acc" in loss_dict["stats"]:
self.train_acc_avg = (
self.train_acc_avg * batch_idx + loss_dict["stats"]["acc"].detach().cpu().item()
) / (batch_idx + 1)
self.log(loss_dict, tag="train")
if self.step_in_epoch % self.validate_interval == 0:
self.validate_epoch(
model=model,
dataloader_val=dataloader_val,
epoch=epoch,
writer=self.writer,
step=batch_idx + 1,
step_in_epoch=self.step_in_epoch,
)
if self.step_in_epoch % self.save_checkpoint_interval == 0:
self.save_checkpoint(
epoch,
model=model,
optim=optim,
scheduler=scheduler,
scaler=scaler,
step=batch_idx + 1,
step_in_epoch=self.step_in_epoch,
data_split_i=kwargs.get("data_split_i", 0),
data_split_num=kwargs.get("data_split_num", 1),
train_loss_avg=self.train_loss_avg,
train_acc_avg=self.train_acc_avg,
)
time_beg = time.perf_counter()
if self.use_ddp or self.use_fsdp or self.use_deepspeed:
train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(self.device)
train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(self.device)
dist.all_reduce(train_loss_avg, op=dist.ReduceOp.SUM)
dist.all_reduce(train_acc_avg, op=dist.ReduceOp.SUM)
self.train_loss_avg = train_loss_avg.detach().cpu().item() / self.world_size
self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
def forward_step(self, model, batch, loss_dict={}):
with maybe_autocast(dtype=self.dtype, use_deepspeed=self.use_deepspeed):
retval = model(**batch)
loss, stats, weight = retval
stats = {k: v for k, v in stats.items() if v is not None}
loss_dict["loss"] = loss
loss_dict["stats"] = stats
loss_dict["weight"] = weight
def backward_step(self, model, scaler, loss_dict={}):
loss = loss_dict["loss"]
if self.use_deepspeed:
scaled_loss = model.backward(loss)
else:
loss = loss / self.accum_grad
if self.use_fp16 or self.use_bf16:
scaler.scale(loss).backward()
else:
loss.backward()
def update_step(self, model, optim, scheduler, scaler, loss_dict=None):
batch_idx = loss_dict["batch_idx"]
if self.use_deepspeed:
model.step()
else:
if (batch_idx + 1) % self.accum_grad == 0:
# Perform gradient clipping if it is set
if self.grad_clip > 0:
grad_norm = torch.nn.utils.clip_grad_norm_(
model.parameters(),
max_norm=self.grad_clip,
norm_type=self.grad_clip_type,
)
if not torch.isfinite(grad_norm):
logging.warning(
f"The grad norm is {grad_norm}. Skipping updating the model."
)
optim.zero_grad() # Reset gradients
return
# Execute an optimization step (update model parameters)
if self.use_ddp or self.use_fsdp:
dist.barrier()
if self.use_fp16 or self.use_bf16:
scaler.step(optim)
scaler.update()
else:
optim.step()
scheduler.step()
# Clear gradients for the next accumulation stage
optim.zero_grad(set_to_none=True)
def validate_epoch(
self,
model=None,
dataloader_val=None,
epoch=None,
writer=None,
**kwargs,
):
"""
Defines the validation process for a single epoch.
Should be implemented with the actual model validation steps.
Args:
epoch (int): The current epoch number.
"""
if self.use_ddp or self.use_fsdp or self.use_deepspeed:
dist.barrier()
logging.info(f"Validate epoch: {epoch}, rank: {self.rank}\n")
model.eval()
with torch.no_grad():
speed_stats = {}
time_beg = time.perf_counter()
time5 = time_beg
dataloader_val.batch_sampler.set_epoch(epoch)
for batch_idx, batch in enumerate(dataloader_val):
loss_dict = {
"speed_stats": {},
"epoch": epoch,
"batch_idx": batch_idx,
"data_split_i": kwargs.get("data_split_i", 0),
"data_split_num": kwargs.get("data_split_num", 1),
"log_step": batch_idx + kwargs.get("start_step", 0),
"batch_total": batch_idx + 1,
"step_in_epoch": batch_idx + 1,
"lr": 0.0,
}
time1 = time.perf_counter()
loss_dict["speed_stats"]["data_load"] = f"{time1 - time_beg:0.3f}"
batch = to_device(batch, self.device)
time2 = time.perf_counter()
self.forward_step(model, batch, loss_dict=loss_dict)
time3 = time.perf_counter()
loss_dict["speed_stats"]["forward_time"] = f"{time3 - time2:0.3f}"
total_time = f"{(time.perf_counter() - time5):0.3f}"
time5 = time.perf_counter()
loss_dict["speed_stats"]["total_time"] = total_time
loss_dict["batch_num_epoch"] = len(dataloader_val)
self.log(loss_dict, tag="val")
time_beg = time.perf_counter()
self.val_loss_avg = (
self.val_loss_avg * batch_idx + loss_dict["loss"].detach().cpu().item()
) / (batch_idx + 1)
if "acc" in loss_dict["stats"]:
self.val_acc_avg = (
self.val_acc_avg * batch_idx
+ loss_dict["stats"]["acc"].detach().cpu().item()
) / (batch_idx + 1)
if self.use_ddp or self.use_fsdp or self.use_deepspeed:
val_loss_avg = torch.tensor(self.val_loss_avg, dtype=torch.float32).to(self.device)
val_acc_avg = torch.tensor(self.val_acc_avg, dtype=torch.float32).to(self.device)
dist.all_reduce(val_loss_avg, op=dist.ReduceOp.SUM)
dist.all_reduce(val_acc_avg, op=dist.ReduceOp.SUM)
self.val_loss_avg = val_loss_avg.detach().cpu().item() / self.world_size
self.val_acc_avg = val_acc_avg.detach().cpu().item() / self.world_size
if kwargs.get("step_in_epoch", None) is None:
ckpt_name = f"model.pt.ep{epoch}"
else:
ckpt_name = f'model.pt.ep{epoch}.{kwargs.get("step_in_epoch")}'
self.val_acc_step_or_eoch[ckpt_name] = self.val_acc_avg
self.val_loss_step_or_eoch[ckpt_name] = self.val_loss_avg
if self.use_ddp or self.use_fsdp or self.use_deepspeed:
dist.barrier()
model.train()
def log(
self,
loss_dict: dict = None,
tag="train",
**kwargs,
):
loss = loss_dict["loss"].detach().cpu().item()
epoch = loss_dict["epoch"]
batch_idx = loss_dict["batch_idx"]
step_in_epoch = loss_dict["step_in_epoch"]
batch_total = loss_dict["batch_total"]
batch_num_epoch = loss_dict["batch_num_epoch"]
lr = loss_dict["lr"]
speed_stats = loss_dict["speed_stats"]
stats = loss_dict["stats"]
data_split_i = loss_dict["data_split_i"]
data_split_num = loss_dict["data_split_num"]
log_step = loss_dict.get("log_step", None)
if (batch_idx + 1) % self.log_interval == 0:
batch_idx = log_step if log_step is not None else batch_idx
gpu_info = (
"GPU, memory: usage: {:.3f} GB, "
"peak: {:.3f} GB, "
"cache: {:.3f} GB, "
"cache_peak: {:.3f} GB".format(
torch.cuda.memory_allocated() / 1024 / 1024 / 1024,
torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024,
torch.cuda.memory_reserved() / 1024 / 1024 / 1024,
torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024,
)
)
loss_avg_epoch = getattr(self, f"{tag}_loss_avg")
acc_avg_epoch = getattr(self, f"{tag}_acc_avg")
description = (
f"{tag}, "
f"rank: {self.rank}, "
f"epoch: {epoch}/{self.max_epoch}, "
f"data_slice: {data_split_i}/{data_split_num}, "
f"step_in_slice: {batch_idx + 1}/{batch_num_epoch}, step_in_epoch: {step_in_epoch}, total step: {batch_total}, "
f"(loss_avg_rank: {loss:.3f}), "
f"(loss_avg_slice: {loss_avg_epoch:.3f}), "
f"(ppl_avg_slice: {math.exp(loss_avg_epoch):.3e}), "
f"(acc_avg_slice: {acc_avg_epoch:.3f}), "
f"(lr: {lr:.3e}), "
f"{[(k, round(v.detach().cpu().item(), 3)) for k, v in stats.items()]}, "
f"{speed_stats}, "
f"{gpu_info}"
)
logging.info(description)
description_dict = {
f"rank{self.rank}_loss/{tag}": loss,
f"rank{self.rank}_lr/{tag}": lr,
}
writer = self.writer
if writer is not None:
writer.add_scalar(f"rank{self.rank}_loss/{tag}", loss, batch_total)
writer.add_scalar(f"rank{self.rank}_lr/{tag}", lr, batch_total)
for key, var in stats.items():
writer.add_scalar(f"stats_rank{self.rank}_{key}/{tag}", var.item(), batch_total)
description_dict[f"stats_rank{self.rank}_{key}/{tag}"] = var.item()
for key, var in speed_stats.items():
writer.add_scalar(f"stats_rank{self.rank}_{key}/{tag}", eval(var), batch_total)
description_dict[f"stats_rank{self.rank}_{key}/{tag}"] = eval(var)
if self.use_wandb and wandb is not None:
wandb.log(
description_dict,
setp=batch_total,
)
def close(self, writer=None):
if self.use_ddp or self.use_fsdp:
dist.barrier()
if writer is not None:
writer.close()
if self.use_ddp or self.use_fsdp:
torch.distributed.destroy_process_group()
def warp_model(self, model, **kwargs):
if self.use_deepspeed:
from deepspeed.runtime.zero.stage_1_and_2 import (
estimate_zero2_model_states_mem_needs_all_live,
)
from deepspeed.runtime.zero.stage3 import estimate_zero3_model_states_mem_needs_all_live
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
world_size = int(os.environ.get("WORLD_SIZE", 1))
# NOTE(xcsong): look in detail how the memory estimator API works:
# https://deepspeed.readthedocs.io/en/latest/memory.html#discussion
if int(os.environ.get("RANK", 0)) == 0:
logging.info("Estimating model states memory needs (zero2)...")
estimate_zero2_model_states_mem_needs_all_live(
model,
num_gpus_per_node=local_world_size,
num_nodes=world_size // local_world_size,
)
logging.info("Estimating model states memory needs (zero3)...")
estimate_zero3_model_states_mem_needs_all_live(
model,
num_gpus_per_node=local_world_size,
num_nodes=world_size // local_world_size,
)
device = None # Init device later
pass # Init DeepSpeed later
elif self.use_ddp:
local_rank = int(os.environ.get("LOCAL_RANK", 0))
model = model.cuda(local_rank)
model = DDP(
model,
device_ids=[local_rank],
find_unused_parameters=kwargs.get("train_conf", {}).get(
"find_unused_parameters", False
),
)
else:
model = model.to(device=kwargs.get("device", "cuda"))
return model
def warp_optim_scheduler(self, model, **kwargs):
from funasr.optimizers import optim_classes
from funasr.schedulers import scheduler_classes
from omegaconf import OmegaConf, DictConfig
import json
# optim
logging.info("Build optim")
optim = kwargs.get("optim", "adam")
assert optim in optim_classes
optim_class = optim_classes.get(optim)
optim = optim_class(model.parameters(), **kwargs.get("optim_conf"))
# scheduler
logging.info("Build scheduler")
scheduler = kwargs.get("scheduler", "warmuplr")
assert scheduler in scheduler_classes
scheduler_class = scheduler_classes.get(scheduler)
scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))
if self.use_deepspeed:
import deepspeed
args = OmegaConf.create({"deepspeed_config": self.deepspeed_config})
with open(self.deepspeed_config, "r") as fin:
ds_configs = json.load(fin)
if "bf16" in ds_configs and ds_configs["bf16"]["enabled"]:
self.dtype = torch.bfloat16
if "fp16" in ds_configs and ds_configs["fp16"]["enabled"]:
self.dtype = torch.float16
if "optimizer" in ds_configs:
# NOTE(xcsong): Disable custom optimizer if it is set in ds_config,
# extremely useful when enable cpu_offload, DeepspeedCpuAdam
# could be 4~5x faster than torch native adam
optim = None
if "scheduler" in ds_configs:
scheduler = None
else:
def scheduler(opt):
return scheduler_class(opt, **kwargs.get("scheduler_conf"))
model, optimizer, _, scheduler = deepspeed.initialize(
args=args,
model=model,
optimizer=optim,
lr_scheduler=scheduler,
model_parameters=model.parameters(),
)
return model, optim, scheduler