Files
EasyFace/modelscope/models/base/base_head.py
2023-03-02 11:17:26 +08:00

40 lines
1.1 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
from abc import ABC, abstractmethod
from typing import Any, Dict, Union
from modelscope.models.base.base_model import Model
from modelscope.utils.config import ConfigDict
from modelscope.utils.logger import get_logger
# Module-level logger obtained from modelscope's logging utility.
logger = get_logger()

# Framework-agnostic tensor alias. Both members are string forward references,
# so this module never has to import torch or tensorflow itself.
Tensor = Union['torch.Tensor', 'tf.Tensor']

# What a head may receive: either a mapping of named tensors or a Model.
Input = Union[Dict[str, Tensor], Model]
class Head(ABC):
    """Abstract base class that defines the interface of a task head.

    A head consumes the output of a backbone model and produces the result
    of a downstream task. During finetuning it also computes the loss.
    Every keyword argument passed to the constructor is stored in
    ``self.config``.
    """

    def __init__(self, **kwargs):
        # The constructor keywords, taken as one mapping, become the head config.
        self.config = ConfigDict(kwargs)

    @abstractmethod
    def forward(self, *args, **kwargs) -> Dict[str, Any]:
        """Run the downstream task on the output of the backbone model.

        Subclasses receive the backbone output through ``args``/``kwargs``.

        Returns:
            Dict[str, Any]: The output of the downstream task.
        """

    @abstractmethod
    def compute_loss(self, *args, **kwargs) -> Dict[str, Any]:
        """Compute the loss of this head during finetuning.

        Returns:
            Dict[str, Any]: The loss dict.
        """