Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add _operation variable #1733

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 52 additions & 4 deletions copier/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
import subprocess
import sys
from contextlib import suppress
from contextvars import ContextVar
from dataclasses import asdict, field, replace
from filecmp import dircmp
from functools import cached_property, partial
from functools import cached_property, partial, wraps
from itertools import chain
from pathlib import Path
from shutil import rmtree
Expand Down Expand Up @@ -60,13 +61,38 @@
MISSING,
AnyByStrDict,
JSONSerializable,
Operation,
ParamSpec,
RelativePath,
StrOrPath,
)
from .user_data import DEFAULT_DATA, AnswersMap, Question
from .vcs import get_git

_T = TypeVar("_T")
_P = ParamSpec("_P")

_operation: ContextVar[Operation] = ContextVar("_operation")


def as_operation(value: Operation) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
"""Decorator to set the current operation context, if not defined already.

This value is used to template specific configuration options.
"""

def _decorator(func: Callable[_P, _T]) -> Callable[_P, _T]:
@wraps(func)
def _wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
token = _operation.set(_operation.get(value))
try:
return func(*args, **kwargs)
finally:
_operation.reset(token)

return _wrapper

return _decorator


@dataclass(config=ConfigDict(extra="forbid"))
Expand Down Expand Up @@ -243,7 +269,7 @@ def _cleanup(self) -> None:
for method in self._cleanup_hooks:
method()

def _check_unsafe(self, mode: Literal["copy", "update"]) -> None:
def _check_unsafe(self, mode: Operation) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to pass ˋmodeˋ if it is in ˋself.operationˋ, right?

Copy link
Contributor Author

@lkubb lkubb Aug 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Initially I did it like that, but noticed that this would cause a behavior change (at least in theory):

During _apply_update(), self.operation is update, but it calls on run_copy() several times, which would pass copy to _check_unsafe() before this patch.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. However, that is correct. You will notice that there are calls to replace. In those calls, you can replace some configuration for the sub-worker that is created. Could you please try doing it that way?

Copy link
Contributor Author

@lkubb lkubb Sep 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I'm not sure I'm following. First, let me sum up:

  • Introducing _copier_conf.operation means we have an attribute on the worker representing the current high-level (user-requested) operation.
  • You're proposing to use this reference for _check_unsafe instead of the parameter.
  • I noted that doing this will change how _check_unsafe behaves during the individual copy operations that run during an update, where the high-level operation is update, but the low-level one is copy, advocating for keeping the parameter.

I'm already using replace for overriding the operation during update. Are you saying the high-level operation during the individual copy operations should be copy? Because that would mean _copier_conf.operation is always copy during template rendering, i.e. defeat the purpose of this feature.

lkubb marked this conversation as resolved.
Show resolved Hide resolved
"""Check whether a template uses unsafe features."""
if self.unsafe:
return
Expand Down Expand Up @@ -296,8 +322,10 @@ def _execute_tasks(self, tasks: Sequence[Task]) -> None:
Arguments:
tasks: The list of tasks to run.
"""
operation = _operation.get()
for i, task in enumerate(tasks):
extra_context = {f"_{k}": v for k, v in task.extra_vars.items()}
extra_context["_operation"] = operation

if not cast_to_bool(self._render_value(task.condition, extra_context)):
continue
Expand Down Expand Up @@ -327,7 +355,7 @@ def _execute_tasks(self, tasks: Sequence[Task]) -> None:
/ Path(self._render_string(str(task.working_directory), extra_context))
).absolute()

extra_env = {k.upper(): str(v) for k, v in task.extra_vars.items()}
extra_env = {k[1:].upper(): str(v) for k, v in extra_context.items()}
with local.cwd(working_directory), local.env(**extra_env):
subprocess.run(task_cmd, shell=use_shell, check=True, env=local.env)

Expand Down Expand Up @@ -588,7 +616,14 @@ def _pathjoin(
@cached_property
def match_exclude(self) -> Callable[[Path], bool]:
"""Get a callable to match paths against all exclusions."""
return self._path_matcher(self.all_exclusions)
# Include the current operation in the rendering context.
# Note: This method is a cached property, it needs to be regenerated
# when reusing an instance in different contexts.
extra_context = {"_operation": _operation.get()}
return self._path_matcher(
self._render_string(exclusion, extra_context=extra_context)
for exclusion in self.all_exclusions
)

@cached_property
def match_skip(self) -> Callable[[Path], bool]:
Expand Down Expand Up @@ -818,6 +853,7 @@ def template_copy_root(self) -> Path:
return self.template.local_abspath / subdir

# Main operations
@as_operation("copy")
def run_copy(self) -> None:
"""Generate a subproject from zero, ignoring what was in the folder.

Expand All @@ -828,6 +864,11 @@ def run_copy(self) -> None:

See [generating a project][generating-a-project].
"""
with suppress(AttributeError):
# We might have switched operation context, ensure the cached property
# is regenerated to re-render templates.
del self.match_exclude

self._check_unsafe("copy")
self._print_message(self.template.message_before_copy)
self._ask()
Expand All @@ -854,6 +895,7 @@ def run_copy(self) -> None:
# TODO Unify printing tools
print("") # padding space

@as_operation("copy")
def run_recopy(self) -> None:
"""Update a subproject, keeping answers but discarding evolution."""
if self.subproject.template is None:
Expand All @@ -864,6 +906,7 @@ def run_recopy(self) -> None:
with replace(self, src_path=self.subproject.template.url) as new_worker:
new_worker.run_copy()

@as_operation("update")
def run_update(self) -> None:
"""Update a subproject that was already generated.

Expand Down Expand Up @@ -911,6 +954,11 @@ def run_update(self) -> None:
print(
f"Updating to template version {self.template.version}", file=sys.stderr
)
with suppress(AttributeError):
# We might have switched operation context, ensure the cached property
# is regenerated to re-render templates.
del self.match_exclude

self._apply_update()
self._print_message(self.template.message_after_update)

Expand Down
7 changes: 7 additions & 0 deletions copier/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Complex types, annotations, validators."""

import sys
from pathlib import Path
from typing import (
Annotated,
Expand All @@ -16,6 +17,11 @@

from pydantic import AfterValidator

if sys.version_info >= (3, 10):
from typing import ParamSpec as ParamSpec
else:
from typing_extensions import ParamSpec as ParamSpec

Check warning on line 23 in copier/types.py

View check run for this annotation

Codecov / codecov/patch

copier/types.py#L23

Added line #L23 was not covered by tests

# simple types
StrOrPath = Union[str, Path]
AnyByStrDict = Dict[str, Any]
Expand All @@ -35,6 +41,7 @@
Env = Mapping[str, str]
MissingType = NewType("MissingType", object)
MISSING = MissingType(object())
Operation = Literal["copy", "update"]


# Validators
Expand Down
17 changes: 17 additions & 0 deletions docs/configuring.md
Original file line number Diff line number Diff line change
Expand Up @@ -893,6 +893,18 @@ to know available options.

The CLI option can be passed several times to add several patterns.

Each pattern can be templated using Jinja.

!!! example

Templating `exclude` patterns using `_operation` allows to have files
that are rendered once during `copy`, but are never updated:

```yaml
_exclude:
- "{% if _operation == 'update' -%}src/*_example.py{% endif %}"
```

!!! info

When you define this parameter in `copier.yml`, it will **replace** the default
Expand Down Expand Up @@ -1351,6 +1363,8 @@ configuring `secret: true` in the [advanced prompt format][advanced-prompt-forma
exist, but always be present. If they do not exist in a project during an `update`
operation, they will be recreated.

Each pattern can be templated using Jinja.

!!! example

For example, it can be used if your project generates a password the 1st time and
Expand Down Expand Up @@ -1501,6 +1515,9 @@ other items not present.
- [invoke, end-process, "--full-conf={{ _copier_conf|to_json }}"]
# Your script can be run by the same Python environment used to run Copier
- ["{{ _copier_python }}", task.py]
# Run a command during the initial copy operation only, excluding updates
- command: ["{{ _copier_python }}", task.py]
when: "{{ _operation == 'copy' }}"
# OS-specific task (supported values are "linux", "macos", "windows" and `None`)
- command: rm {{ name_of_the_project }}/README.md
when: "{{ _copier_conf.os in ['linux', 'macos'] }}"
Expand Down
10 changes: 10 additions & 0 deletions docs/creating.md
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,16 @@ The absolute path of the Python interpreter running Copier.

The name of the project root directory.

## Variables (context-dependent)

Some variables are only available in select contexts:

### `_operation`

The current operation, either `"copy"` or `"update"`.

Availability: [`exclude`](configuring.md#exclude), [`tasks`](configuring.md#tasks)

## Variables (context-specific)

Some rendering contexts provide variables unique to them:
Expand Down
90 changes: 90 additions & 0 deletions tests/test_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import json
from pathlib import Path

import pytest
from plumbum import local

import copier

from .helpers import build_file_tree, git_save


def test_exclude_templating_with_operation(
tmp_path_factory: pytest.TempPathFactory,
) -> None:
"""
Ensure it's possible to create one-off boilerplate files that are not
managed during updates via `_exclude` using the `_operation` context variable.
"""
src, dst = map(tmp_path_factory.mktemp, ("src", "dst"))

template = "{% if _operation == 'update' %}copy-only{% endif %}"
with local.cwd(src):
build_file_tree(
{
"copier.yml": f'_exclude:\n - "{template}"',
"{{ _copier_conf.answers_file }}.jinja": "{{ _copier_answers|to_yaml }}",
"copy-only": "foo",
"copy-and-update": "foo",
}
)
git_save(tag="1.0.0")
build_file_tree(
{
"copy-only": "bar",
"copy-and-update": "bar",
}
)
git_save(tag="2.0.0")
copy_only = dst / "copy-only"
copy_and_update = dst / "copy-and-update"

copier.run_copy(str(src), dst, defaults=True, overwrite=True, vcs_ref="1.0.0")
for file in (copy_only, copy_and_update):
assert file.exists()
assert file.read_text() == "foo"

with local.cwd(dst):
git_save()

copier.run_update(str(dst), overwrite=True)
assert copy_only.read_text() == "foo"
assert copy_and_update.read_text() == "bar"


def test_task_templating_with_operation(
tmp_path_factory: pytest.TempPathFactory, tmp_path: Path
) -> None:
"""
Ensure that it is possible to define tasks that are only executed when copying.
"""
src, dst = map(tmp_path_factory.mktemp, ("src", "dst"))
# Use a file outside the Copier working directories to ensure accurate tracking
task_counter = tmp_path / "task_calls.txt"
with local.cwd(src):
build_file_tree(
{
"copier.yml": (
f"""\
_tasks:
- command: echo {{{{ _operation }}}} >> {json.dumps(str(task_counter))}
when: "{{{{ _operation == 'copy' }}}}"
"""
),
"{{ _copier_conf.answers_file }}.jinja": "{{ _copier_answers|to_yaml }}",
}
)
git_save(tag="1.0.0")

copier.run_copy(str(src), dst, defaults=True, overwrite=True, unsafe=True)
assert task_counter.exists()
assert len(task_counter.read_text().splitlines()) == 1

with local.cwd(dst):
git_save()

copier.run_recopy(dst, defaults=True, overwrite=True, unsafe=True)
assert len(task_counter.read_text().splitlines()) == 2

copier.run_update(dst, defaults=True, overwrite=True, unsafe=True)
assert len(task_counter.read_text().splitlines()) == 2
Loading