Files
EasyFace/modelscope/hub/utils/caching.py
2023-03-02 11:17:26 +08:00

291 lines
9.7 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import hashlib
import os
import pickle
import tempfile
from shutil import move, rmtree
from modelscope.hub.constants import MODEL_META_FILE_NAME, MODEL_META_MODEL_ID
from modelscope.utils.logger import get_logger
logger = get_logger()
"""Implements caching functionality, used internally only
"""
class FileSystemCache(object):
KEY_FILE_NAME = '.msc'
"""Local file cache.
"""
def __init__(
self,
cache_root_location: str,
**kwargs,
):
"""Base file system cache interface.
Args:
cache_root_location (str): The root location to store files.
kwargs(dict): The keyword arguments.
"""
os.makedirs(cache_root_location, exist_ok=True)
self.cache_root_location = cache_root_location
self.load_cache()
def get_root_location(self):
return self.cache_root_location
def load_cache(self):
self.cached_files = []
cache_keys_file_path = os.path.join(self.cache_root_location,
FileSystemCache.KEY_FILE_NAME)
if os.path.exists(cache_keys_file_path):
with open(cache_keys_file_path, 'rb') as f:
self.cached_files = pickle.load(f)
def save_cached_files(self):
"""Save cache metadata."""
# save new meta to tmp and move to KEY_FILE_NAME
cache_keys_file_path = os.path.join(self.cache_root_location,
FileSystemCache.KEY_FILE_NAME)
# TODO: Sync file write
fd, fn = tempfile.mkstemp()
with open(fd, 'wb') as f:
pickle.dump(self.cached_files, f)
move(fn, cache_keys_file_path)
def get_file(self, key):
"""Check the key is in the cache, if exist, return the file, otherwise return None.
Args:
key(str): The cache key.
Raises:
None
"""
pass
def put_file(self, key, location):
"""Put file to the cache.
Args:
key (str): The cache key
location (str): Location of the file, we will move the file to cache.
Raises:
None
"""
pass
def remove_key(self, key):
"""Remove cache key in index, The file is removed manually
Args:
key (dict): The cache key.
"""
if key in self.cached_files:
self.cached_files.remove(key)
self.save_cached_files()
def exists(self, key):
for cache_file in self.cached_files:
if cache_file == key:
return True
return False
def clear_cache(self):
"""Remove all files and metadata from the cache
In the case of multiple cache locations, this clears only the last one,
which is assumed to be the read/write one.
"""
rmtree(self.cache_root_location)
self.load_cache()
def hash_name(self, key):
return hashlib.sha256(key.encode()).hexdigest()
class ModelFileSystemCache(FileSystemCache):
"""Local cache file layout
cache_root/owner/model_name/individual cached files and cache index file '.mcs'
Save only one version for each file.
"""
def __init__(self, cache_root, owner=None, name=None):
"""Put file to the cache
Args:
cache_root(`str`): The modelscope local cache root(default: ~/.cache/modelscope/)
owner(`str`): The model owner.
name('str'): The name of the model
Returns:
Raises:
None
<Tip>
model_id = {owner}/{name}
</Tip>
"""
if owner is None or name is None:
# get model meta from
super().__init__(os.path.join(cache_root))
self.load_model_meta()
else:
super().__init__(os.path.join(cache_root, owner, name))
self.model_meta = {MODEL_META_MODEL_ID: '%s/%s' % (owner, name)}
self.save_model_meta()
def load_model_meta(self):
meta_file_path = os.path.join(self.cache_root_location,
MODEL_META_FILE_NAME)
if os.path.exists(meta_file_path):
with open(meta_file_path, 'rb') as f:
self.model_meta = pickle.load(f)
else:
self.model_meta = {MODEL_META_MODEL_ID: 'unknown'}
def get_model_id(self):
return self.model_meta[MODEL_META_MODEL_ID]
def save_model_meta(self):
meta_file_path = os.path.join(self.cache_root_location,
MODEL_META_FILE_NAME)
with open(meta_file_path, 'wb') as f:
pickle.dump(self.model_meta, f)
def get_file_by_path(self, file_path):
"""Retrieve the cache if there is file match the path.
Args:
file_path (str): The file path in the model.
Returns:
path: the full path of the file.
"""
for cached_file in self.cached_files:
if file_path == cached_file['Path']:
cached_file_path = os.path.join(self.cache_root_location,
cached_file['Path'])
if os.path.exists(cached_file_path):
return cached_file_path
else:
self.remove_key(cached_file)
return None
def get_file_by_path_and_commit_id(self, file_path, commit_id):
"""Retrieve the cache if there is file match the path.
Args:
file_path (str): The file path in the model.
commit_id (str): The commit id of the file
Returns:
path: the full path of the file.
"""
for cached_file in self.cached_files:
if file_path == cached_file['Path'] and \
(cached_file['Revision'].startswith(commit_id) or commit_id.startswith(cached_file['Revision'])):
cached_file_path = os.path.join(self.cache_root_location,
cached_file['Path'])
if os.path.exists(cached_file_path):
return cached_file_path
else:
self.remove_key(cached_file)
return None
def get_file_by_info(self, model_file_info):
"""Check if exist cache file.
Args:
model_file_info (ModelFileInfo): The file information of the file.
Returns:
str: The file path.
"""
cache_key = self.__get_cache_key(model_file_info)
for cached_file in self.cached_files:
if cached_file == cache_key:
orig_path = os.path.join(self.cache_root_location,
cached_file['Path'])
if os.path.exists(orig_path):
return orig_path
else:
self.remove_key(cached_file)
break
return None
def __get_cache_key(self, model_file_info):
cache_key = {
'Path': model_file_info['Path'],
'Revision': model_file_info['Revision'], # commit id
}
return cache_key
def exists(self, model_file_info):
"""Check the file is cached or not.
Args:
model_file_info (CachedFileInfo): The cached file info
Returns:
bool: If exists return True otherwise False
"""
key = self.__get_cache_key(model_file_info)
is_exists = False
for cached_key in self.cached_files:
if cached_key['Path'] == key['Path'] and (
cached_key['Revision'].startswith(key['Revision'])
or key['Revision'].startswith(cached_key['Revision'])):
is_exists = True
break
file_path = os.path.join(self.cache_root_location,
model_file_info['Path'])
if is_exists:
if os.path.exists(file_path):
return True
else:
self.remove_key(
model_file_info) # someone may manual delete the file
return False
def remove_if_exists(self, model_file_info):
"""We in cache, remove it.
Args:
model_file_info (ModelFileInfo): The model file information from server.
"""
for cached_file in self.cached_files:
if cached_file['Path'] == model_file_info['Path']:
self.remove_key(cached_file)
file_path = os.path.join(self.cache_root_location,
cached_file['Path'])
if os.path.exists(file_path):
os.remove(file_path)
break
def put_file(self, model_file_info, model_file_location):
"""Put model on model_file_location to cache, the model first download to /tmp, and move to cache.
Args:
model_file_info (str): The file description returned by get_model_files.
model_file_location (str): The location of the temporary file.
Returns:
str: The location of the cached file.
"""
self.remove_if_exists(model_file_info) # backup old revision
cache_key = self.__get_cache_key(model_file_info)
cache_full_path = os.path.join(
self.cache_root_location,
cache_key['Path']) # Branch and Tag do not have same name.
cache_file_dir = os.path.dirname(cache_full_path)
if not os.path.exists(cache_file_dir):
os.makedirs(cache_file_dir, exist_ok=True)
# We can't make operation transaction
move(model_file_location, cache_full_path)
self.cached_files.append(cache_key)
self.save_cached_files()
return cache_full_path