mirror of
https://github.com/deepinsight/insightface.git
synced 2025-12-30 08:02:27 +00:00
74 lines
1.8 KiB
Python
74 lines
1.8 KiB
Python
import oneflow as flow
|
|
import oneflow.nn as nn
|
|
|
|
|
|
def make_static_grad_scaler():
    """Create a fixed-factor gradient scaler whose factor is the world size.

    The factor handed to ``flow.amp.StaticGradScaler`` is whatever
    ``flow.env.get_world_size()`` reports for the current job.

    Returns:
        A ``flow.amp.StaticGradScaler`` instance.
    """
    world_size = flow.env.get_world_size()
    scaler = flow.amp.StaticGradScaler(world_size)
    return scaler
|
|
|
|
|
|
def make_grad_scaler(
    *,
    init_scale=2 ** 30,
    growth_factor=2.0,
    backoff_factor=0.5,
    growth_interval=2000,
):
    """Create a dynamic loss scaler for mixed-precision (AMP) training.

    The keyword defaults reproduce the configuration that used to be
    hard-coded here. Calling ``make_grad_scaler()`` with no arguments
    therefore behaves exactly as before, and any hyper-parameter can now
    be overridden by keyword.

    Args:
        init_scale: Initial loss-scale value.
        growth_factor: Factor by which the scale is multiplied when it grows,
            as defined by ``flow.amp.GradScaler``.
        backoff_factor: Factor by which the scale is multiplied when it backs
            off, as defined by ``flow.amp.GradScaler``.
        growth_interval: Step interval governing scale growth, as defined by
            ``flow.amp.GradScaler``.

    Returns:
        A ``flow.amp.GradScaler`` configured with the given values.
    """
    return flow.amp.GradScaler(
        init_scale=init_scale,
        growth_factor=growth_factor,
        backoff_factor=backoff_factor,
        growth_interval=growth_interval,
    )
|
|
|
|
|
|
def meter(self, mkey, *args):
    """Forward ``args`` to the ``record`` method of the meter registered as ``mkey``.

    ``self.m`` is indexed as ``self.m[mkey]["meter"]``. It is therefore
    expected to map each key to a dict that holds the meter object under
    ``"meter"``.

    NOTE(review): this is a module-level function that takes ``self``.
    Presumably it is meant to be attached to a class that owns ``m`` —
    confirm against callers.

    Args:
        mkey: Key into ``self.m``; it must already be registered.
        *args: Positional values passed unchanged to ``record``.

    Raises:
        AssertionError: If ``mkey`` is missing and assertions are enabled.
            When assertions are stripped (``python -O``), the dict lookup
            raises ``KeyError`` instead.
    """
    registry = self.m
    assert mkey in registry
    entry = registry[mkey]
    entry["meter"].record(*args)
|
|
|
|
|
|
class TrainGraph(flow.nn.Graph):
    """Static training graph.

    Each call loads a batch, runs the model, applies the margin and the scale,
    computes the loss and backpropagates.

    Args:
        model: Module called as ``model(image, label)``; it must return a
            ``(logits, label)`` pair.
        cfg: Config object. ``cfg.fp16`` enables AMP with a dynamic loss
            scaler. Otherwise ``cfg.scale_grad`` installs a static gradient
            scaler.
        combine_margin: Module called as ``combine_margin(logits, label)``.
        cross_entropy: Loss module called as ``cross_entropy(logits, label)``.
        data_loader: Module whose call returns an ``(image, label)`` pair.
        optimizer: Optimizer registered on the graph.
        lr_scheduler: Optional LR scheduler passed to ``add_optimizer``.
        margin_scale: Multiplier applied to the margin-adjusted logits. It
            defaults to 64, the value that was previously hard-coded in
            ``build``.
    """

    def __init__(
        self,
        model,
        cfg,
        combine_margin,
        cross_entropy,
        data_loader,
        optimizer,
        lr_scheduler=None,
        *,
        margin_scale=64,
    ):
        super().__init__()

        if cfg.fp16:
            # Mixed precision: turn on AMP and use a dynamic loss scaler.
            self.config.enable_amp(True)
            self.set_grad_scaler(make_grad_scaler())
        elif cfg.scale_grad:
            # No AMP, but gradients are still scaled by a fixed factor.
            self.set_grad_scaler(make_static_grad_scaler())

        self.config.allow_fuse_add_to_output(True)
        self.config.allow_fuse_model_update_ops(True)

        self.model = model

        self.cross_entropy = cross_entropy
        self.combine_margin = combine_margin
        self.data_loader = data_loader
        # Scale multiplied onto the margin-adjusted logits in build().
        self.margin_scale = margin_scale
        self.add_optimizer(optimizer, lr_sch=lr_scheduler)

    def build(self):
        """Run one training step and return the loss tensor."""
        image, label = self.data_loader()

        image = image.to("cuda")
        label = label.to("cuda")

        # The model returns its own label alongside the logits. The label is
        # rebound here, presumably because the model may remap it (e.g. a
        # partial/sampled FC) — confirm against the model implementation.
        logits, label = self.model(image, label)
        logits = self.combine_margin(logits, label) * self.margin_scale
        loss = self.cross_entropy(logits, label)

        loss.backward()
        return loss
|
|
|
|
|
|
class EvalGraph(flow.nn.Graph):
    """Static inference graph that runs ``model`` on a batch of images.

    Args:
        model: Module invoked on the input inside ``build``.
        cfg: Config object. Only ``cfg.fp16`` is read; it toggles AMP, the
            same switch ``TrainGraph`` checks.
    """

    def __init__(self, model, cfg):
        super().__init__()
        self.config.allow_fuse_add_to_output(True)
        self.model = model
        if not cfg.fp16:
            return
        self.config.enable_amp(True)

    def build(self, image):
        """Return the model output for ``image``."""
        return self.model(image)
|