Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 10 additions & 11 deletions doc/user_guide/user_scripts.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,23 +44,22 @@ PSyclone User Scripts
The standard way to transform a codebase using psyclone is through the
:ref:`psyclone_command` tool, which has an optional ``-s <SCRIPT_NAME>``
flag that allows users to specify a transformation user script to
programmatically modify the input code::
programmatically modify the input code using a relative or absolute
path::

> psyclone -s optimise.py input_source.f90

In this case, the current directory is prepended to the Python search path
``PYTHONPATH`` which will then be used to try to find the script file. Thus,
the search begins in the current directory and continues over any pre-existing
directories in the search path, failing if the file cannot be found.

Alternatively, script files may be specified with a path. In this case
the file must exist in the specified location. This location is then added to
the Python search path ``PYTHONPATH`` as before. For example::

> psyclone -s ./optimise.py input_source.f90
> psyclone -s ../scripts/optimise.py input_source.f90
> psyclone -s /home/me/PSyclone/scripts/optimise.py input_source.f90

PSyclone will not take ``PYTHONPATH`` into account when importing the script,
so the user must ensure that the script is found: For this, it must either be
in the current working directory (if no path is specified at all), or be
specified with a valid (relative or absolute) path. PSyclone will add the
directory of the script to its Python's ``sys.path``, meaning the PSyclone
script can import any helper script in the same directory without additional
setup.

A valid PSyclone user script file must contain a ``trans`` function which accepts
a :ref:`PSyIR node<psyir-ug>` representing the root of the psy-layer
code (as a FileContainer):
Expand Down
28 changes: 20 additions & 8 deletions src/psyclone/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,14 @@
'''

import argparse
import importlib
import logging
import os
import pathlib
import shutil
import sys
import traceback
import importlib
import shutil
from typing import Callable, Iterable, List, Optional, Tuple, Union
import logging

from fparser.api import get_reader
from fparser.two import Fortran2003
Expand Down Expand Up @@ -149,12 +150,18 @@ def load_script(
raise GenerationError(
f"generator: expected the script file '{filename}' to have "
f"the '.py' extension")
# prepend file path - if none, the empty string equates to the current
# working directory - to the system path to guarantee we find the user
# provided module instead of a similarly named module that might
# already exist elsewhere in the system path

# Add the script directory to sys.path, so scripts can easily import
# helper scripts in the same directory (this step is not needed to
# import the script itself, but it maintains backwards compatibility).
sys.path.insert(0, filepath)
recipe_module = importlib.import_module(module_name)

# This will import the module, but not make it part of the
# system list of all modules, i.e. it can be used elsewhere.
script_path = pathlib.Path(script_name)
spec = importlib.util.spec_from_file_location(module_name, script_path)
recipe_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(recipe_module)

if hasattr(recipe_module, "FILES_TO_SKIP"):
files_to_skip = recipe_module.FILES_TO_SKIP
Expand Down Expand Up @@ -287,6 +294,7 @@ def generate(filename: str,
# Apply provided recipe to PSyIR
recipe, _, _ = load_script(script_name)
recipe(psy.container.root)
del sys.path[0]
alg_gen = None

elif api in GOCEAN_API_NAMES or (api in LFRIC_API_NAMES and LFRIC_TESTING):
Expand Down Expand Up @@ -340,6 +348,7 @@ def generate(filename: str,
is_optional=True)
if recipe:
recipe(psyir)
del sys.path[0]

# For each kernel called from the algorithm layer
kernels = {}
Expand Down Expand Up @@ -428,6 +437,7 @@ def generate(filename: str,
# Call the optimisation script for psy-layer optimisations
recipe, _, _ = load_script(script_name)
recipe(psy.container.root)
del sys.path[0]

# TODO issue #1618 remove Alg class and tests from PSyclone
if api in LFRIC_API_NAMES and not LFRIC_TESTING:
Expand Down Expand Up @@ -979,3 +989,5 @@ def code_transformation_mode(input_file, recipe_file, output_file,
else:
print(f"File '{input_file}' skipped because it is listed in "
"FILES_TO_SKIP.", file=sys.stdout)

del sys.path[0]
93 changes: 37 additions & 56 deletions src/psyclone/tests/generator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@
import re
import shutil
import stat
from sys import modules
from typing import Optional
import pytest

Expand Down Expand Up @@ -84,37 +83,18 @@
"test_files", "gocean1p0")


@pytest.fixture(name="script_factory", scope="function")
def create_script_factor(tmpdir):
''' Fixture that creates a psyclone optimisation script given the string
def script_factory(tmpdir, code: str) -> Path:
"""
Function that creates a psyclone optimisation script given the string
representing the body of the script:

script_path = script_factory("def trans(psyir):\n pass")

It has a 'function' scope and a tear down section because using a script
imports the file and this is kept in the python interpreter state, so we
delete it for future tests.

'''
"""
tmpfile = os.path.join(tmpdir, "test_script.py")

def populate_script(string):
with open(tmpfile, 'w+', encoding="utf8") as script:
script.write(string)
return tmpfile

yield populate_script
# Tear down section executed after each test that uses the fixture
# If the created script was used, then its module (file) was imported
# into the interpreter runtime, we need to make sure it is deleted
modname = "test_script"
if modname in modules:
del modules[modname]
for mod in modules.values():
try:
delattr(mod, modname)
except AttributeError:
pass
with open(tmpfile, 'w+', encoding="utf8") as script:
script.write(code)
return tmpfile


def test_script_file_not_found():
Expand Down Expand Up @@ -165,14 +145,14 @@ def test_script_file_wrong_extension():
"extension" in str(error.value))


def test_script_invalid_content(script_factory):
def test_script_invalid_content(tmpdir):
'''Checks that load_script() in generator.py raises the expected
exception when a script file does not contain valid python. This
test uses the generate() function to call load_script as this is
a simple way to create its required arguments.

'''
error_syntax = script_factory("""
error_syntax = script_factory(tmpdir, """
this is invalid python
""")
with pytest.raises(Exception) as err:
Expand All @@ -181,7 +161,7 @@ def test_script_invalid_content(script_factory):
api="lfric", script_name=error_syntax)
assert "invalid syntax (test_script.py, line 2)" in str(err.value)

error_import = script_factory("""
error_import = script_factory(tmpdir, """
import non_existent
""")
with pytest.raises(Exception) as err:
Expand All @@ -191,15 +171,15 @@ def test_script_invalid_content(script_factory):
assert "No module named 'non_existent'" in str(err.value)


def test_script_invalid_content_runtime(script_factory):
def test_script_invalid_content_runtime(tmpdir):
'''Checks that load_script() function in generator.py raises the
expected exception when a script file contains valid python
syntactically but produces a runtime exception. This test uses the
generate() function to call load_script as this is a simple way
to create its required arguments.

'''
runtime_error = script_factory("""
runtime_error = script_factory(tmpdir, """
def trans(psyir):
# this will produce a runtime error as b has not been assigned
psyir = b
Expand All @@ -211,15 +191,15 @@ def trans(psyir):
assert "name 'b' is not defined" in str(error.value)


def test_script_no_trans(script_factory):
def test_script_no_trans(tmpdir):
'''Checks that load_script() function in generator.py raises the
expected exception when a script file does not contain a trans()
function. This test uses the generate() function to call
load_script as this is a simple way to create its required
arguments.

'''
no_trans_script = script_factory("""
no_trans_script = script_factory(tmpdir, """
def nottrans(psyir):
pass

Expand All @@ -235,7 +215,7 @@ def tran():
in str(error.value))


def test_script_no_trans_alg(capsys, script_factory):
def test_script_no_trans_alg(capsys, tmpdir):
'''Checks that load_script() function in generator.py does not raise
an exception when a script file does not contain a trans_alg()
function as these are optional. At the moment this function is
Expand All @@ -244,7 +224,7 @@ def test_script_no_trans_alg(capsys, script_factory):
its required arguments.

'''
no_alg_script = script_factory("def trans(psyir):\n pass")
no_alg_script = script_factory(tmpdir, "def trans(psyir):\n pass")
_, _ = generate(
os.path.join(BASE_PATH, "gocean1p0", "single_invoke.f90"),
api="gocean", script_name=no_alg_script)
Expand All @@ -254,7 +234,7 @@ def test_script_no_trans_alg(capsys, script_factory):
assert "Deprecation warning:" not in captured.err


def test_script_with_legacy_trans_signature(capsys, script_factory):
def test_script_with_legacy_trans_signature(capsys, tmpdir):
'''Checks that load_script() function in generator.py does not raise
an exception when a script file uses the legacy trans signature.

Expand All @@ -264,7 +244,7 @@ def test_script_with_legacy_trans_signature(capsys, script_factory):
This will eventually be deprecated.

'''
legacy_script = script_factory("""
legacy_script = script_factory(tmpdir, """
def trans(psy):
# The following are backwards-compatible expressions with legacy scripts
_ = psy.invokes.invoke_list
Expand Down Expand Up @@ -440,13 +420,13 @@ def test_no_script_gocean():
assert "module psy_single_invoke_test" in str(psy)


def test_script_gocean(script_factory):
def test_script_gocean(tmpdir):
'''Test that the generate function in generator.py returns
successfully if a script (containing both trans_alg() and trans()
functions) is specified.

'''
alg_script = script_factory("""
alg_script = script_factory(tmpdir, """
def trans_alg(psyir):
pass

Expand Down Expand Up @@ -496,12 +476,12 @@ def _broken(_1, _2):
assert "Failed to create PSyIR from file '" in err


def test_script_attr_error(script_factory):
def test_script_attr_error(tmpdir):
'''Checks that generator.py raises an appropriate error when a script
file contains a trans() function which raises an attribute error.

'''
error_script = script_factory("""
error_script = script_factory(tmpdir, """
from psyclone.psyGen import Loop
from psyclone.transformations import ColourTrans

Expand All @@ -520,12 +500,12 @@ def trans(psyir):
assert 'object has no attribute' in str(excinfo.value)


def test_script_null_trans(script_factory):
def test_script_null_trans(tmpdir):
'''Checks that generator.py works correctly when the trans() function
in a valid script file does no transformations.

'''
empty_script = script_factory("def trans(psyir):\n pass")
empty_script = script_factory(tmpdir, "def trans(psyir):\n pass")
alg1, psy1 = generate(os.path.join(BASE_PATH, "lfric",
"1_single_invoke.f90"),
api="lfric")
Expand All @@ -540,40 +520,41 @@ def test_script_null_trans(script_factory):
'\n'.join(str(psy2).split('\n')[1:])


def test_script_null_trans_relative(script_factory):
def test_script_null_trans_relative(monkeypatch, tmpdir):
'''Checks that generator.py works correctly when the trans() function
in a valid script file does no transformations (it simply passes
input to output). In this case the valid script file contains no
path and must therefore be found via the PYTHOPATH path list.
path, but is invoked from the script's directory, as a relative
path.

'''
alg1, psy1 = generate(os.path.join(BASE_PATH, "lfric",
"1_single_invoke.f90"),
api="lfric")
empty_script = script_factory("def trans(psyir):\n pass")
empty_script = script_factory(tmpdir, "def trans(psyir):\n pass")
basename = os.path.basename(empty_script)
path = os.path.dirname(empty_script)
# Set the script directory in the PYTHONPATH
os.sys.path.append(path)

# Change into the script's directory so it's found as a relative import.
monkeypatch.chdir(tmpdir)

alg2, psy2 = generate(os.path.join(BASE_PATH, "lfric",
"1_single_invoke.f90"),
api="lfric", script_name=basename)
# Remove the path from PYTHONPATH
os.sys.path.pop()

# we need to remove the first line before comparing output as
# this line is an instance specific header
assert '\n'.join(str(alg1).split('\n')[1:]) == \
'\n'.join(str(alg2).split('\n')[1:])
assert str(psy1) == str(psy2)


def test_script_trans_lfric(script_factory):
def test_script_trans_lfric(tmpdir):
'''Checks that generator.py works correctly when a transformation is
provided as a script, i.e. it applies the transformations
correctly.

'''
fuse_loop_script = script_factory("""
fuse_loop_script = script_factory(tmpdir, """
from psyclone.domain.lfric.transformations import LFRicLoopFuseTrans
def trans(psyir):
module = psyir.children[0]
Expand Down Expand Up @@ -1896,7 +1877,7 @@ def test_no_script_lfric_new(monkeypatch):
assert "use _psyclone_builtins" not in alg


def test_script_lfric_new(monkeypatch, script_factory):
def test_script_lfric_new(monkeypatch, tmpdir):
'''Test that the generate function in generator.py returns
successfully if a script (containing both trans_alg() and trans()
functions) is specified. This test uses the new PSyIR approach to
Expand All @@ -1905,7 +1886,7 @@ def test_script_lfric_new(monkeypatch, script_factory):
monkeypatching.

'''
alg_script = script_factory("""
alg_script = script_factory(tmpdir, """
def trans_alg(psyir):
pass

Expand Down
Loading