Skip to content

Commit

Permalink
Use mock import instead of dummy parent package to implement relative…
Browse files Browse the repository at this point in the history
… import
  • Loading branch information
mzr1996 committed Jul 18, 2023
1 parent 1793ea4 commit 5646bf0
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 133 deletions.
197 changes: 83 additions & 114 deletions mmengine/config/new_config.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
import builtins
import importlib
import inspect
import os
import platform
import sys
from importlib.abc import Loader, MetaPathFinder
from importlib.machinery import PathFinder
from importlib.util import spec_from_loader
from pathlib import Path
from types import BuiltinFunctionType, FunctionType, ModuleType
from typing import Optional, Tuple, Union
Expand All @@ -16,6 +16,7 @@
from .lazy import LazyImportContext, LazyObject

RESERVED_KEYS = ['filename', 'text', 'pretty_text']
_CFG_UID = 0

if platform.system() == 'Windows':
import regex as re
Expand Down Expand Up @@ -115,8 +116,7 @@ class ConfigV2(Config):
.. _config tutorial: https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html
""" # noqa: E501
_max_parent_depth = 4
_parent_pkg = '_cfg_parent'
_pkg_prefix = '_mmengine_cfg'

def __init__(self,
cfg_dict: dict = None,
Expand Down Expand Up @@ -161,7 +161,7 @@ def _sanity_check(cfg):
for v in cfg:
ConfigV2._sanity_check(v)
elif isinstance(cfg, (type, FunctionType)):
if (ConfigV2._parent_pkg in cfg.__module__
if (ConfigV2._pkg_prefix in cfg.__module__
or '__main__' in cfg.__module__):
msg = ('You cannot use temporary functions '
'as the value of a field.\n\n')
Expand Down Expand Up @@ -211,22 +211,29 @@ def fromfile(filename: Union[str, Path],
format_python_code=format_python_code)
finally:
ConfigDict.lazy = False
global _CFG_UID
_CFG_UID = 0
for mod in list(sys.modules):
if mod.startswith(ConfigV2._pkg_prefix):
del sys.modules[mod]

return cfg

@staticmethod
def _get_config_module(filename: Union[str, Path], level=0):
def _get_config_module(filename: Union[str, Path]):
file = Path(filename).absolute()
module_name = re.sub(r'\W|^(?=\d)', '_', file.stem)
parent_pkg = ConfigV2._parent_pkg + str(level)
fullname = '.'.join([parent_pkg] * ConfigV2._max_parent_depth +
[module_name])
global _CFG_UID
# Build a unique module name to avoid conflict.
fullname = f'{ConfigV2._pkg_prefix}{_CFG_UID}_{module_name}'
_CFG_UID += 1

# import config file as a module
with LazyImportContext():
spec = importlib.util.spec_from_file_location(fullname, file)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
sys.modules[fullname] = module

return module

Expand Down Expand Up @@ -338,14 +345,16 @@ def _format_basic_types(input_):

return text

def __getstate__(self) -> Tuple[dict, Optional[str], Optional[str]]:
return (self._cfg_dict, self._filename, self._text)
def __getstate__(self) -> Tuple[dict, Optional[str], Optional[str], bool]:
return (self._cfg_dict, self._filename, self._text,
self._format_python_code)

def __setstate__(self, state: Tuple[dict, Optional[str], Optional[str]]):
_cfg_dict, _filename, _text = state
super(Config, self).__setattr__('_cfg_dict', _cfg_dict)
super(Config, self).__setattr__('_filename', _filename)
super(Config, self).__setattr__('_text', _text)
def __setstate__(self, state: Tuple[dict, Optional[str], Optional[str],
bool]):
super(Config, self).__setattr__('_cfg_dict', state[0])
super(Config, self).__setattr__('_filename', state[1])
super(Config, self).__setattr__('_text', state[2])
super(Config, self).__setattr__('_format_python_code', state[3])

def _to_lazy_dict(self, keep_imported: bool = False) -> dict:
"""Convert config object to dictionary and filter the imported
Expand Down Expand Up @@ -383,118 +392,78 @@ def lazy2string(cfg_dict):
return lazy2string(_cfg_dict)


class BaseConfigLoader(Loader):

def __init__(self, filepath, level) -> None:
self.filepath = filepath
self.level = level

def create_module(self, spec):
file = self.filepath
return ConfigV2._get_config_module(file, level=self.level)

def exec_module(self, module):
for k in dir(module):
module.__dict__[k] = ConfigV2._dict_to_config_dict_lazy(
getattr(module, k))


class ParentFolderLoader(Loader):

@staticmethod
def create_module(spec):
return ModuleType(spec.name)

@staticmethod
def exec_module(module):
pass


class BaseImportContext(MetaPathFinder):

def find_spec(self, fullname, path=None, target=None):
"""Try to find a spec for 'fullname' on sys.path or 'path'.
The search is based on sys.path_hooks and sys.path_importer_cache.
"""
parent_pkg = ConfigV2._parent_pkg + str(self.level)
names = fullname.split('.')

if names[-1] == parent_pkg:
self.base_modules.append(fullname)
# Create parent package
return spec_from_loader(
fullname, loader=ParentFolderLoader, is_package=True)
elif names[0] == parent_pkg:
self.base_modules.append(fullname)
# relative imported base package
filepath = self.root_path
for name in names:
if name == parent_pkg:
# Use parent to remove `..` at the end of the root path
filepath = filepath.parent
else:
filepath = filepath / name
if filepath.is_dir():
# If a dir, create a package.
return spec_from_loader(
fullname, loader=ParentFolderLoader, is_package=True)

pypath = filepath.with_suffix('.py')

if not pypath.exists():
raise ImportError(f'Not found base path {filepath.resolve()}')
return importlib.util.spec_from_loader(
fullname, BaseConfigLoader(pypath, self.level + 1))
else:
# Absolute import
pkg = PathFinder.find_spec(names[0])
if pkg and pkg.submodule_search_locations:
self.base_modules.append(fullname)
path = Path(pkg.submodule_search_locations[0])
for name in names[1:]:
path = path / name
if path.is_dir():
return spec_from_loader(
fullname, loader=ParentFolderLoader, is_package=True)
pypath = path.with_suffix('.py')
if not pypath.exists():
raise ImportError(f'Not found base path {path.resolve()}')
return importlib.util.spec_from_loader(
fullname, BaseConfigLoader(pypath, self.level + 1))
return None
class BaseImportContext():

def __enter__(self):
# call from which file
stack = inspect.stack()[1]
file = inspect.getfile(stack[0])
folder = Path(file).parent
self.root_path = folder.joinpath(*(['..'] *
ConfigV2._max_parent_depth))

self.base_modules = []
self.level = len(
[p for p in sys.meta_path if isinstance(p, BaseImportContext)])

# Disable enabled lazy loader during parsing base
self.lazy_importers = []
for p in sys.meta_path:
if isinstance(p, LazyImportContext) and p.enable:
self.lazy_importers.append(p)
p.enable = False

index = sys.meta_path.index(importlib.machinery.FrozenImporter)
sys.meta_path.insert(index + 1, self)
old_import = builtins.__import__

def new_import(name, globals=None, locals=None, fromlist=(), level=0):
cur_file = None

# Try to import the base config source file
if level != 0 and globals is not None:
# For relative import path
if '__file__' in globals:
loc = Path(globals['__file__']).parent
else:
loc = Path(os.getcwd())
cur_file = self.find_relative_file(loc, name, level - 1)
if not cur_file.exists():
raise ImportError(f'Cannot import name "{name}" from '
f'{loc}: {cur_file} does not exist.')
elif level == 0:
# For absolute import path
pkg, _, mod = name.partition('.')
pkg = PathFinder.find_spec(pkg)
if mod and pkg.submodule_search_locations:
loc = Path(pkg.submodule_search_locations[0])
cur_file = self.find_relative_file(loc, mod)
if not cur_file.exists():
raise ImportError(f'Cannot import name "{name}": '
f'{cur_file} does not exist.')

# Recover the original import during handle the base config file.
builtins.__import__ = old_import

if cur_file is not None:
mod = ConfigV2._get_config_module(cur_file)

for k in dir(mod):
mod.__dict__[k] = ConfigV2._dict_to_config_dict_lazy(
getattr(mod, k))
else:
mod = old_import(
name, globals, locals, fromlist=fromlist, level=level)

builtins.__import__ = new_import

return mod

self.old_import = old_import
builtins.__import__ = new_import

def __exit__(self, exc_type, exc_val, exc_tb):
sys.meta_path.remove(self)
for name in self.base_modules:
sys.modules.pop(name, None)
builtins.__import__ = self.old_import
for p in self.lazy_importers:
p.enable = True

def __repr__(self):
return f'<BaseImportContext (level={self.level})>'
@staticmethod
def find_relative_file(loc: Path, relative_import_path, level=0):
if level > 0:
loc = loc.parents[level - 1]
names = relative_import_path.lstrip('.').split('.')

for name in names:
loc = loc / name

return loc.with_suffix('.py')


read_base = BaseImportContext
20 changes: 11 additions & 9 deletions mmengine/config/old_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,14 +666,16 @@ def env_variables(self) -> dict:
"""get used environment variables."""
return self._env_variables

def __getstate__(self) -> Tuple[dict, Optional[str], Optional[str], dict]:
return (self._cfg_dict, self._filename, self._text,
self._env_variables)
def __getstate__(
self) -> Tuple[dict, Optional[str], Optional[str], dict, bool]:
state = (self._cfg_dict, self._filename, self._text,
self._env_variables, self._format_python_code)
return state

def __setstate__(self, state: Tuple[dict, Optional[str], Optional[str],
dict]):
_cfg_dict, _filename, _text, _env_variables = state
super(Config, self).__setattr__('_cfg_dict', _cfg_dict)
super(Config, self).__setattr__('_filename', _filename)
super(Config, self).__setattr__('_text', _text)
super(Config, self).__setattr__('_text', _env_variables)
dict, bool]):
super(Config, self).__setattr__('_cfg_dict', state[0])
super(Config, self).__setattr__('_filename', state[1])
super(Config, self).__setattr__('_text', state[2])
super(Config, self).__setattr__('_env_variables', state[3])
super(Config, self).__setattr__('_format_python_code', state[4])
2 changes: 1 addition & 1 deletion mmengine/registry/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from rich.console import Console
from rich.table import Table

from mmengine.config.utils import MODULE2PACKAGE
from mmengine.config.lazy import LazyObject
from mmengine.config.utils import MODULE2PACKAGE
from mmengine.utils import is_seq_of
from .default_scope import DefaultScope

Expand Down
13 changes: 4 additions & 9 deletions tests/test_config/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -986,8 +986,9 @@ def test_lazy_import(self, tmp_path):
cfg_dict = cfg.to_dict()
assert (cfg_dict['train_dataloader']['dataset']['type'] ==
'<mmengine.testing.runner_test_case.ToyDataset>')
assert (
cfg_dict['custom_hooks'][0]['type'] == '<mmengine.hooks.EMAHook>')
assert (cfg_dict['custom_hooks'][0]['type']
in ('<mmengine.hooks.EMAHook>',
'<mmengine.hooks.ema_hook.EMAHook>'))
# Dumped config
dumped_cfg_path = tmp_path / 'test_dump_lazy.py'
cfg.dump(dumped_cfg_path)
Expand Down Expand Up @@ -1060,12 +1061,6 @@ def _compare_dict(a, b):
osp.join(self.data_path,
'config/lazy_module_config/error_mix_using1.py'))

# Force to import in non-lazy-import mode
Config.fromfile(
osp.join(self.data_path,
'config/lazy_module_config/error_mix_using1.py'),
lazy_import=False)

# current lazy-import config, base text config
with pytest.raises(AttributeError, match='item2'):
Config.fromfile(
Expand All @@ -1088,7 +1083,7 @@ def _compare_dict(a, b):
dumped_cfg = Config.fromfile(dumped_cfg_path)

assert set(dumped_cfg.keys()) == {
'path', 'name', 'suffix', 'chained', 'existed', 'cfgname'
'path', 'name', 'suffix', 'chained', 'existed', 'cfgname', 'ex'
}
assert dumped_cfg.to_dict() == cfg.to_dict()

Expand Down

0 comments on commit 5646bf0

Please sign in to comment.