diff --git a/src/psyclone/parse/module_manager.py b/src/psyclone/parse/module_manager.py index 1d544b9432..485040f82d 100644 --- a/src/psyclone/parse/module_manager.py +++ b/src/psyclone/parse/module_manager.py @@ -32,6 +32,7 @@ # POSSIBILITY OF SUCH DAMAGE. # ----------------------------------------------------------------------------- # Author J. Henrichs, Bureau of Meteorology +# Modifications: M. Naylor, University of Cambridge, UK '''This module contains a singleton class that manages information about which module is contained in which file (including full location). ''' @@ -141,6 +142,11 @@ def __init__(self): self._module_pattern = re.compile(r"^\s*module\s+([a-z]\S*)\s*$", flags=re.IGNORECASE | re.MULTILINE) + # Files with an extension from this set will be considered + # when searching for Fortran files + self._fortran_file_exts = {".F90", ".f90", ".X90", ".x90", + ".F95", ".f95", ".F03", ".f03"} + @property def cache_active(self) -> bool: ''' @@ -210,6 +216,29 @@ def resolve_indirect_imports(self, value: Union[bool, Iterable[str]]): f"str, but found an item of {type(x)}") self._resolve_indirect_imports = value + @property + def fortran_file_exts(self) -> set[str]: + ''' + :returns: the set of file extensions that are considered when + searching for Fortran files. + ''' + return self._fortran_file_exts + + @fortran_file_exts.setter + def fortran_file_exts(self, exts: set[str]): + ''' + :param exts: the set of file extensions that are considered when + searching for Fortran files. + + :raises TypeError: if the provided value is not a set. + ''' + if not isinstance(exts, set): + raise TypeError( + f"'fortran_file_exts' must be a set, but found " + f"{type(exts).__name__}" + ) + self._fortran_file_exts = exts + # ------------------------------------------------------------------------ def add_search_path(self, directories: Union[str, Path, @@ -254,7 +283,7 @@ def add_search_path(self, # ------------------------------------------------------------------------ def _add_all_files_from_dir(self, directory: str) -> list[FileInfo]: '''This function creates (and caches) FileInfo objects for all files - with an extension of (F/f/X/x)90 in the given directory that have + with a Fortran file extension in the given directory that have not previously been visited. The new FileInfo objects are returned. :param directory: the directory containing Fortran files @@ -269,7 +298,7 @@ def _add_all_files_from_dir(self, directory: str) -> list[FileInfo]: for entry in all_entries: _, ext = os.path.splitext(entry.name) if (not entry.is_file() or - ext not in [".F90", ".f90", ".X90", ".x90"]): + ext not in self._fortran_file_exts): continue full_path = os.path.join(directory, entry.name) if full_path in self._visited_files: @@ -321,7 +350,7 @@ def _find_module_in_files( self._modules[name] = mod_info # A file that has been (or does not require) # preprocessing always takes precedence. - if finfo.filename.endswith(".f90"): + if self._doesnt_need_preprocessing(finfo.filename): return mod_info return mod_info @@ -490,6 +519,19 @@ def all_file_infos(self) -> list[FileInfo]: """ return list(self._filepath_to_file_info.values()) + def _doesnt_need_preprocessing(self, filename: str) -> bool: + """Returns True if the file with the given filename + doesn't need preprocessing.""" + # The current method is just to check that the file extension + # is a lower-case Fortran file extension and doesn't begin + # with '.x'. (The latter condition is present to preserve + # previous behaviour, although it's unclear if that behavior + # is indeed desired.) + base, ext = os.path.splitext(filename) + return (ext in self._fortran_file_exts and + ext.islower() and + not ext.startswith(".x")) + def get_module_info(self, module_name: str) -> Optional[ModuleInfo]: """This function returns the ModuleInfo for the specified module. @@ -509,16 +551,15 @@ def get_module_info(self, module_name: str) -> Optional[ModuleInfo]: return None # First check if we have already seen this module. We only end the - # search early if the file we've found does not require pre-processing - # (i.e. has a .f90 suffix). + # search early if the file we've found does not require pre-processing. mod_info = self._modules.get(mod_lower, None) - if mod_info and mod_info.filename.endswith(".f90"): + if mod_info and self._doesnt_need_preprocessing(mod_info.filename): return mod_info old_mod_info = mod_info # Are any of the files that we've already seen a good match? mod_info = self._find_module_in_files(mod_lower, self._visited_files.values()) - if mod_info and mod_info.filename.endswith(".f90"): + if mod_info and self._doesnt_need_preprocessing(mod_info.filename): return mod_info old_mod_info = mod_info diff --git a/src/psyclone/tests/parse/module_manager_test.py b/src/psyclone/tests/parse/module_manager_test.py index 57a1f2d8af..44d1bab7b4 100644 --- a/src/psyclone/tests/parse/module_manager_test.py +++ b/src/psyclone/tests/parse/module_manager_test.py @@ -547,3 +547,30 @@ def test_mod_manager_load_all_module_trigger_error_file_read_twice() -> None: mod_man.load_all_module_infos(error_if_file_already_processed=True) assert "File 't_mod.f90' already processed" in str(einfo.value) + + +@pytest.mark.usefixtures("clear_module_manager_instance") +def test_mod_manager_fortran_file_exts() -> None: + '''Tests functionality for modifying managing Fortran file etensions.''' + mod_man = ModuleManager.get() + + # Check that the default extensions include '.f90' and '.F90' + assert '.f90' in mod_man.fortran_file_exts + assert '.F90' in mod_man.fortran_file_exts + + # Check that we can add and remove a new Fortran file extension + assert '.foo' not in mod_man.fortran_file_exts + mod_man.fortran_file_exts.add(".foo") + assert '.foo' in mod_man.fortran_file_exts + assert '.f90' in mod_man.fortran_file_exts + mod_man.fortran_file_exts.remove(".foo") + assert '.foo' not in mod_man.fortran_file_exts + assert '.f90' in mod_man.fortran_file_exts + mod_man.fortran_file_exts = {".f95"} + assert mod_man.fortran_file_exts == {".f95"} + + # Check that type errors are spotted + with pytest.raises(TypeError) as err: + mod_man.fortran_file_exts = True + assert ("'fortran_file_exts' must be a set, but found bool" + in str(err.value))