Files
EasyFace/modelscope/msdatasets/utils/dataset_utils.py
2023-03-02 11:17:26 +08:00

223 lines
7.6 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from collections import defaultdict
from typing import Optional, Union
from modelscope.hub.api import HubApi
from modelscope.utils.constant import DEFAULT_DATASET_REVISION, MetaDataFields
from modelscope.utils.logger import get_logger
# Module-level logger obtained from modelscope's logging helper; shared by the
# functions in this module (e.g. to report a defaulted subset name).
logger = get_logger()
def format_dataset_structure(dataset_structure):
    """
    Keep only the splits that declare a meta file or a data file.

    Args:
        dataset_structure (dict): Mapping of split name to split info, e.g.
            {"train": {"meta": "my_train.csv", "file": "pictures.zip"}}.

    Returns:
        dict: A new dict (original order preserved) holding only the entries
        whose info has a truthy ``meta`` or a truthy ``file`` value.
    """
    kept_splits = {}
    for split_name, split_info in dataset_structure.items():
        has_payload = split_info.get('meta') or split_info.get('file')
        if has_payload:
            kept_splits[split_name] = split_info
    return kept_splits
def get_target_dataset_structure(dataset_structure: dict,
                                 subset_name: Optional[str] = None,
                                 split: Optional[str] = None):
    """
    Pick one subset, and optionally one split, out of a dataset structure.

    Args:
        dataset_structure (dict): Dataset Structure, like
            {
                "default":{
                    "train":{
                        "meta":"my_train.csv",
                        "file":"pictures.zip"
                    }
                },
                "subsetA":{
                    "test":{
                        "meta":"mytest.csv",
                        "file":"pictures.zip"
                    }
                }
            }
        subset_name (str, optional): Defining the subset_name of the dataset.
            When omitted (or empty), the structure must contain exactly one
            subset, which is then selected.
        split (str, optional): Which split of the data to load. When omitted,
            every split of the chosen subset that declares a ``meta`` or
            ``file`` entry is returned.

    Returns:
        target_subset_name (str): Name of the chosen subset.
        target_dataset_structure (dict): Structure of the chosen split(s), like
            {
                "test":{
                    "meta":"mytest.csv",
                    "file":"pictures.zip"
                }
            }

    Raises:
        ValueError: If ``subset_name`` is given but not present; if
            ``subset_name`` is omitted and ``dataset_structure`` is empty or
            holds more than one subset; or if ``split`` is given but is not
            among the chosen subset's splits that declare ``meta``/``file``.
    """
    # verify dataset subset
    if subset_name:
        if subset_name not in dataset_structure:
            raise ValueError(
                f'subset_name {subset_name} not found. Available: {dataset_structure.keys()}'
            )
        target_subset_name = subset_name
    else:
        # Guard explicitly: `next(iter({}))` would leak a StopIteration.
        if not dataset_structure:
            raise ValueError(
                'No subset_name specified and dataset_structure is empty.')
        # Name the real problem instead of reporting "subset_name None not found".
        if len(dataset_structure) > 1:
            raise ValueError(
                'subset_name is required when the dataset has more than one '
                f'subset. Available: {dataset_structure.keys()}')
        target_subset_name = next(iter(dataset_structure))
        logger.info(
            f'No subset_name specified, defaulting to the {target_subset_name}'
        )

    # verify dataset split
    target_dataset_structure = format_dataset_structure(
        dataset_structure[target_subset_name])
    if split:
        if split not in target_dataset_structure:
            raise ValueError(
                f'split {split} not found. Available: {target_dataset_structure.keys()}'
            )
        target_dataset_structure = {split: target_dataset_structure[split]}
    return target_subset_name, target_dataset_structure
def list_dataset_objects(hub_api: HubApi, max_limit: int, is_recursive: bool,
                         dataset_name: str, namespace: str,
                         version: str) -> list:
    """
    Collect the keys of the oss objects stored for a dataset.

    Directory entries are filtered out by the hub call (``is_filter_dir=True``),
    and entries whose ``Key`` is missing or empty are dropped.

    Args:
        hub_api (class HubApi): HubApi instance used to query the listing.
        max_limit (int): Max number of objects.
        is_recursive (bool): Whether to list objects recursively.
        dataset_name (str): Dataset name.
        namespace (str): Namespace.
        version (str): Dataset version, passed to the hub as ``revision``.

    Returns:
        list: Object keys in listing order, e.g.
            ['train/images/001.png', 'train/images/002.png', 'val/images/001.png', ...]
    """
    listing = hub_api.list_oss_dataset_objects(
        dataset_name=dataset_name,
        namespace=namespace,
        max_limit=max_limit,
        is_recursive=is_recursive,
        is_filter_dir=True,
        revision=version)
    candidate_keys = (entry.get('Key') for entry in listing)
    return [key for key in candidate_keys if key]
def contains_dir(file_map) -> bool:
    """
    To check whether input contains at least one directory.

    A value counts as a directory when it is a string that does not end with
    ``.zip``; non-string values (e.g. lists of object keys) never count.

    Args:
        file_map (dict): Structure of data files. e.g., {'train': 'train.zip', 'validation': 'val.zip'}

    Returns:
        True if input contains at least one directory, False otherwise
        (including for an empty ``file_map``).
    """
    # Only the values matter; `any` short-circuits on the first directory.
    return any(
        isinstance(location, str) and not location.endswith('.zip')
        for location in file_map.values())
def get_subdir_hash_from_split(split: Union[str, list], version: str) -> str:
    """
    Build a relative sub-path identifying a version and a split selection.

    The split names are joined with ``_`` and placed under ``version`` using the
    platform path separator, e.g. ``('train', 'v1')`` -> ``v1/train`` and
    ``(['train', 'val'], 'v1')`` -> ``v1/train_val`` on POSIX.

    Args:
        split (str or list): A single split name or a list of split names.
        version (str): Dataset version used as the leading path component.

    Returns:
        str: The joined path.
    """
    split_names = [split] if isinstance(split, str) else split
    joined_names = '_'.join(split_names)
    return os.path.join(version, joined_names)
def get_split_list(split: Union[str, list]) -> list:
    """
    Unify the split to list-format.

    Args:
        split (str or list): A single split name, or a list of split names.

    Returns:
        list: ``[split]`` for a string; the given list object itself otherwise.

    Raises:
        TypeError: If ``split`` is neither a str nor a list.
    """
    if isinstance(split, str):
        return [split]
    if isinstance(split, list):
        return split
    # Raising a bare string is itself a TypeError in Python 3 and discards the
    # message; raise a real TypeError so the message reaches the caller.
    raise TypeError(
        f'Expected format of split: str or list, but got {type(split)}.')
def get_split_objects_map(file_map, objects):
    """
    Get the map between dataset split and oss objects.

    An object belongs to a split when its key starts with that split's
    directory followed by ``/``. A trailing ``/`` on the directory name is
    tolerated (``'train/'`` behaves like ``'train'``). Splits whose value is
    not a string have no directory prefix and map to an empty list.

    Args:
        file_map (dict): Structure of data files. e.g., {'train': 'train', 'validation': 'val'}, both of train and val
            are dirs.
        objects (list): List of oss objects. e.g., ['train/001/1_123.png', 'train/001/1_124.png', 'val/003/3_38.png']

    Returns:
        A map of split-objects. e.g., {'train': ['train/001/1_123.png', 'train/001/1_124.png'],
            'validation':['val/003/3_38.png']}
    """
    res = {split_name: [] for split_name in file_map}
    # Compute each split's prefix once instead of per object; strip a trailing
    # '/' so a value like 'train/' does not become the unmatchable 'train//'.
    prefixes = [(split_name, location.rstrip('/') + '/')
                for split_name, location in file_map.items()
                if isinstance(location, str)]
    for obj_key in objects:
        for split_name, prefix in prefixes:
            if obj_key.startswith(prefix):
                res[split_name].append(obj_key)
    return res
def _read_big_data_objects(modelscope_api, meta_csv_file_url):
    """
    Return the non-empty first-column values of one split's meta-csv.

    Args:
        modelscope_api (HubApi): Hub client used to fetch the csv content.
        meta_csv_file_url (str): Url of the split's meta-csv file.

    Returns:
        list: A fresh list of object keys, one per non-empty line whose first
        comma-separated field is non-empty.

    Raises:
        ValueError: If the fetched csv content is empty.
    """
    _, script_content = modelscope_api.fetch_single_csv_script(
        meta_csv_file_url)
    if not script_content:
        raise ValueError(
            'Meta-csv file cannot be empty when meta-args `big_data` is true.')
    # NOTE(review): lines are split on ',' without csv quoting rules, and no
    # line is skipped as a header — confirm the meta-csv format matches.
    split_objects = []
    for line in script_content:
        if not line:
            continue
        object_key = line.strip().split(',')[0]
        if object_key:
            split_objects.append(object_key)
    return split_objects


def get_dataset_files(subset_split_into: dict,
                      dataset_name: str,
                      namespace: str,
                      revision: Optional[str] = DEFAULT_DATASET_REVISION):
    """
    Resolve meta-file urls and data-file locations for every split of a subset.

    Args:
        subset_split_into (dict): Split structure of one subset, like
            {
                "test": {"meta": "mytest.csv", "file": "pictures.zip"}
            }
        dataset_name (str): Dataset name.
        namespace (str): Namespace.
        revision (str, optional): Dataset revision.

    Return:
        meta_map: Structure of meta files (.csv), the meta file name will be replaced by url, like
        {
           "test": "https://xxx/mytest.csv"
        }
        file_map: Structure of data files (.zip), like
        {
            "test": "pictures.zip"
        }
        For a split whose args set `big_data`, the value is that split's own
        list of object keys read from its meta-csv. If no such object was
        collected and some split points to a directory, every split is mapped
        to the listed oss objects under its directory.
        args_map: The `args` entry of each split (None when absent).

    Raises:
        ValueError: If a split sets `big_data` but its meta-csv content is empty.
    """
    meta_map = defaultdict(dict)
    file_map = defaultdict(dict)
    args_map = defaultdict(dict)
    modelscope_api = HubApi()

    for split, info in subset_split_into.items():
        meta_map[split] = modelscope_api.get_dataset_file_url(
            info.get('meta', ''), dataset_name, namespace, revision)
        if info.get('file'):
            file_map[split] = info['file']
        args_map[split] = info.get('args')

    # If `big_data` is true, then fetch objects from meta-csv file directly.
    # Each split gets its own list; a single shared list would hand every
    # big_data split the objects of all the other big_data splits too.
    found_big_data_objects = False
    for split, args_dict in args_map.items():
        if args_dict and args_dict.get(MetaDataFields.ARGS_BIG_DATA):
            split_objects = _read_big_data_objects(modelscope_api,
                                                   meta_map[split])
            file_map[split] = split_objects
            found_big_data_objects = found_big_data_objects or bool(
                split_objects)

    # More general but low-efficiency: list every object of the dataset.
    if not found_big_data_objects:
        objects = list_dataset_objects(hub_api=modelscope_api,
                                       max_limit=-1,
                                       is_recursive=True,
                                       dataset_name=dataset_name,
                                       namespace=namespace,
                                       version=revision)
        if contains_dir(file_map):
            file_map = get_split_objects_map(file_map, objects)

    return meta_map, file_map, args_map