# Source code for django_tenant_options.management.commands.removetriggers

"""Management command to generate migrations for removing triggers created by maketriggers command.

This module provides a Django management command that automatically generates database migrations
to remove triggers that were previously created for models inheriting from AbstractSelection.
"""

from __future__ import annotations

import re
from argparse import ArgumentParser
from argparse import RawTextHelpFormatter
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from textwrap import dedent

from django.apps import apps
from django.core.management.base import BaseCommand
from django.core.management.base import CommandError
from django.core.management.base import CommandParser
from django.db.migrations.recorder import MigrationRecorder

from django_tenant_options.models import AbstractSelection


@dataclass(frozen=True)  # Immutable; __eq__/__hash__ below are kept by dataclass because they are defined explicitly
class TriggerInfo:
    """Contains information about an identified trigger.

    Instances are immutable and hashable, so they can be collected in sets.
    Identity is defined solely by ``trigger_name`` and ``migration_file``;
    ``model_name`` and ``app_label`` are carried as metadata and take no part
    in equality or hashing.
    """

    trigger_name: str  # Text captured after "DROP TRIGGER IF EXISTS" in an existing migration
    migration_file: Path  # Migration file in which the trigger was found
    model_name: str  # Model the trigger was looked up for
    app_label: str  # App label of that model

    def __eq__(self, other: object) -> bool:
        """Compare two triggers by name and originating migration file.

        Args:
            other: Object to compare with

        Returns:
            True if both the trigger name and the migration file match,
            False otherwise; ``NotImplemented`` for non-``TriggerInfo`` objects.
        """
        if not isinstance(other, TriggerInfo):
            return NotImplemented
        return (self.trigger_name, self.migration_file) == (other.trigger_name, other.migration_file)

    def __hash__(self) -> int:
        """Hash the same fields that ``__eq__`` compares.

        Returns:
            Hash value for the object
        """
        # Hash the Path itself, not str(path): Path.__hash__ is consistent with
        # Path.__eq__ (e.g. case-insensitive for Windows paths), so objects that
        # compare equal are guaranteed to hash equal.
        return hash((self.trigger_name, self.migration_file))


class Command(BaseCommand):
    """Management command to generate migrations that remove previously created triggers.

    This command scans existing migrations to identify triggers created by the
    maketriggers command and generates new migrations to remove them.
    """

    help = dedent(
        """\
        Generate migrations to remove triggers previously created by maketriggers command.

        This command will:
        1. Scan existing migrations to find triggers
        2. Create new migrations to drop the identified triggers
        3. Optionally verify trigger removal if --verify flag is used
        """
    )

    # Matches the trigger identifier (everything up to the ";") in SQL such as
    # "DROP TRIGGER IF EXISTS name ON table;". Compiled once instead of per file.
    _DROP_TRIGGER_PATTERN = re.compile(r"DROP TRIGGER IF EXISTS ([^;]+);")

    def __init__(self, *args, **kwargs):
        """Initialize command with default values for all configuration options."""
        super().__init__(*args, **kwargs)
        self.app_label: str | None = None
        self.model_name: str | None = None
        self.dry_run: bool = False
        self.migration_dir: str | None = None
        self.interactive: bool = False
        self.verbose: bool = False
        self.verify: bool = False
        self.last_generated_migration: str | None = None

    def create_parser(self, prog_name: str, subcommand: str, **kwargs) -> CommandParser:
        """Create a command parser that preserves newlines in help text.

        Args:
            prog_name: The name of the program
            subcommand: The name of the subcommand
            **kwargs: Additional parser arguments

        Returns:
            A CommandParser instance with RawTextHelpFormatter
        """
        parser = super().create_parser(prog_name, subcommand, **kwargs)
        parser.formatter_class = RawTextHelpFormatter
        return parser

    def add_arguments(self, parser: ArgumentParser) -> None:
        """Add command-line arguments to the parser.

        Args:
            parser: The argument parser to add arguments to
        """
        parser.add_argument(
            "--app",
            type=str,
            metavar="app_name",
            help="Specify the app to remove triggers from.",
        )
        parser.add_argument(
            "--model",
            type=str,
            metavar="app_name.ModelName",
            help="Specify the model to remove triggers for (format: app_label.ModelName).",
        )
        parser.add_argument(
            "--dry-run",
            action="store_true",
            help=(
                "Simulate the migration creation process without writing any files. "
                "Use with --verbose to see the migration file content that would be created."
            ),
        )
        parser.add_argument(
            "--migration-dir",
            type=str,
            metavar="directory",
            help="Specify a custom directory to save migration files.",
        )
        parser.add_argument(
            "--interactive",
            action="store_true",
            help="Prompt for confirmation before creating each migration.",
        )
        parser.add_argument(
            "--verbose",
            action="store_true",
            help="Provide detailed output of the migration creation process.",
        )
        parser.add_argument(
            "--verify",
            action="store_true",
            help="Verify that identified triggers exist in the database before removing.",
        )

    def handle(self, *args, **options) -> None:
        """Execute the command logic based on provided options.

        Args:
            *args: Positional arguments
            **options: Command options as key-value pairs

        Raises:
            CommandError: If --model is malformed or names an unknown app/model,
                or if --app names an unknown app.
        """
        self._initialize_options(options)

        if self.verify:
            # NOTE(review): no database verification logic exists in this command;
            # tell the user instead of silently ignoring the flag.
            self.stdout.write(
                self.style.WARNING("--verify is not implemented yet; no database verification will be performed.")
            )

        if self.model_name:
            app_label, model_name = self._split_model_label(self.model_name)
            self._handle_single_model(app_label, model_name)
        elif self.app_label:
            self._handle_app_models()
        else:
            self._handle_all_models()

    def _initialize_options(self, options: dict) -> None:
        """Initialize command options from parsed arguments.

        Args:
            options: Dictionary of command options
        """
        self.app_label = options.get("app")
        self.model_name = options.get("model")
        self.dry_run = options.get("dry_run", False)
        self.migration_dir = options.get("migration_dir")
        self.interactive = options.get("interactive", False)
        self.verbose = options.get("verbose", False)
        self.verify = options.get("verify", False)

    @staticmethod
    def _split_model_label(label: str) -> tuple[str, str]:
        """Split an ``app_label.ModelName`` string into its two parts.

        Args:
            label: Value passed to --model

        Returns:
            Tuple of (app_label, model_name)

        Raises:
            CommandError: If the value is not exactly two non-empty dotted parts.
        """
        app_label, sep, model_name = label.partition(".")
        if not sep or not app_label or not model_name or "." in model_name:
            raise CommandError(f"Invalid --model value {label!r}; expected format app_label.ModelName.")
        return app_label, model_name

    @staticmethod
    def _is_selection_model(model) -> bool:
        """Return True for concrete subclasses of AbstractSelection (not the base itself)."""
        return issubclass(model, AbstractSelection) and model is not AbstractSelection

    def _handle_single_model(self, app_label: str, model_name: str) -> None:
        """Process trigger removal for a single model.

        Args:
            app_label: Label of the Django app containing the model
            model_name: Name of the model class

        Raises:
            CommandError: If the app or model cannot be found.
        """
        try:
            model = apps.get_model(app_label, model_name)
        except LookupError as exc:
            raise CommandError(str(exc)) from exc

        if not self._is_selection_model(model):
            self.stdout.write(
                self.style.WARNING(f"Model {model_name} is not a subclass of AbstractSelection. Skipping...")
            )
            return
        self._process_triggers(self._find_triggers_for_model(app_label, model_name))

    def _handle_app_models(self) -> None:
        """Process trigger removal for all eligible models in the specified app.

        Raises:
            CommandError: If the app cannot be found.
        """
        try:
            app_config = apps.get_app_config(self.app_label)
        except LookupError as exc:
            raise CommandError(str(exc)) from exc
        self._process_triggers(self._collect_triggers(app_config.get_models()))

    def _handle_all_models(self) -> None:
        """Process trigger removal for all eligible models across all apps."""
        self._process_triggers(self._collect_triggers(apps.get_models()))

    def _collect_triggers(self, models) -> list[TriggerInfo]:
        """Gather triggers for every AbstractSelection subclass in ``models``."""
        triggers: list[TriggerInfo] = []
        for model in models:
            if self._is_selection_model(model):
                triggers.extend(self._find_triggers_for_model(model._meta.app_label, model.__name__))
        return triggers

    def _find_triggers_for_model(self, app_label: str, model_name: str) -> list[TriggerInfo]:
        """Find all triggers associated with a specific model.

        A migration file is considered to belong to the model when its file name
        matches ``auto_trigger_<model>`` or ``trigger...<model>`` (lower-cased);
        each ``DROP TRIGGER IF EXISTS ...;`` statement in such a file yields one
        trigger.

        Args:
            app_label: Label of the Django app containing the model
            model_name: Name of the model class

        Returns:
            List of TriggerInfo objects for found triggers, in file-name order
        """
        triggers: list[TriggerInfo] = []
        migrations_dir = self._get_migrations_dir(app_label)
        if not migrations_dir.exists():
            return triggers

        lowered = re.escape(model_name.lower())
        file_pattern = re.compile(rf"auto_trigger_{lowered}|trigger.*{lowered}")

        # Sorted so the result does not depend on filesystem enumeration order.
        for migration_file in sorted(migrations_dir.glob("*.py")):
            if not file_pattern.search(migration_file.name):
                continue
            # Django writes migration files as UTF-8; don't depend on the locale.
            content = migration_file.read_text(encoding="utf-8")
            triggers.extend(
                TriggerInfo(
                    trigger_name=match.group(1),
                    migration_file=migration_file,
                    model_name=model_name,
                    app_label=app_label,
                )
                for match in self._DROP_TRIGGER_PATTERN.finditer(content)
            )
        return triggers

    def _process_triggers(self, triggers: list[TriggerInfo]) -> None:
        """Process the identified triggers and create removal migrations.

        Args:
            triggers: List of TriggerInfo objects to process
        """
        if not triggers:
            self.stdout.write(self.style.WARNING("No triggers found to remove."))
            return

        # Group triggers by app_label to create one migration per app
        triggers_by_app: dict[str, set[TriggerInfo]] = {}
        for trigger in triggers:
            triggers_by_app.setdefault(trigger.app_label, set()).add(trigger)

        for app_label, app_triggers in triggers_by_app.items():
            self._create_removal_migration(app_label, app_triggers)

    def _create_removal_migration(self, app_label: str, triggers: set[TriggerInfo]) -> None:
        """Create a migration to remove the specified triggers.

        An existing file at the target path is never overwritten; the app is
        skipped with an error message instead.

        Args:
            app_label: Label of the Django app
            triggers: Set of TriggerInfo objects for triggers to remove
        """
        migration_name = self._construct_migration_name(app_label)
        migration_path = self._get_migration_path(app_label, migration_name)

        if migration_path.exists():
            # The name is derived from the last *applied* migration, so an unapplied
            # earlier run produces the same name; refuse to clobber that file.
            self.stderr.write(
                f"Migration {migration_path} already exists; refusing to overwrite it. Skipping {app_label}."
            )
            return

        migration_content = self._generate_migration_content(app_label, triggers)

        if self.dry_run:
            self._handle_dry_run(migration_path, triggers, migration_content)
            return

        if self.interactive and not self._confirm_creation(app_label, triggers):
            return

        migration_path.parent.mkdir(parents=True, exist_ok=True)
        migration_path.write_text(migration_content, encoding="utf-8")
        self.stdout.write(self.style.SUCCESS(f"Created migration: {migration_path}"))

    def _get_last_migration_name(self, app_label: str) -> str | None:
        """Return the name of the most recently applied migration for an app, if any.

        Args:
            app_label: Label of the Django app

        Returns:
            Migration name, or None when the app has no applied migrations
        """
        last_migration = MigrationRecorder.Migration.objects.filter(app=app_label).order_by("applied").last()
        return last_migration.name if last_migration else None

    def _construct_migration_name(self, app_label: str) -> str:
        """Construct a name for the new migration file.

        Args:
            app_label: Label of the Django app

        Returns:
            ``NNNN_remove_triggers`` numbered one past the last applied migration,
            or ``remove_triggers`` when that number cannot be determined
        """
        last_name = self._get_last_migration_name(app_label)
        if last_name and (match := re.match(r"^(\d+)_", last_name)):
            return f"{int(match.group(1)) + 1:04d}_remove_triggers"
        return "remove_triggers"

    def _get_migration_path(self, app_label: str, migration_name: str) -> Path:
        """Get the full path for the new migration file (without touching the disk).

        Args:
            app_label: Label of the Django app
            migration_name: Name of the migration file

        Returns:
            Path object for the migration file
        """
        return self._get_migrations_dir(app_label) / f"{migration_name}.py"

    def _get_migrations_dir(self, app_label: str) -> Path:
        """Get the migrations directory for an app.

        Args:
            app_label: Label of the Django app

        Returns:
            Path object for the migrations directory (--migration-dir if given)
        """
        if self.migration_dir:
            return Path(self.migration_dir)
        return Path(apps.get_app_config(app_label).path) / "migrations"

    def _handle_dry_run(
        self,
        migration_path: Path,
        triggers: set[TriggerInfo],
        migration_content: str | None = None,
    ) -> None:
        """Handle dry run mode for migration creation.

        Args:
            migration_path: Path where the migration would be created
            triggers: Set of TriggerInfo objects for triggers to remove
            migration_content: Generated file content, printed when --verbose is set
        """
        self.stdout.write(self.style.SUCCESS(f"[DRY RUN] Would create migration: {migration_path}"))
        if self.verbose:
            names = ", ".join(sorted(t.trigger_name for t in triggers))
            self.stdout.write(f"[DRY RUN] Would remove triggers: {names}")
            if migration_content is not None:
                # --dry-run's help promises the file content when combined with --verbose.
                self.stdout.write(migration_content)

    def _confirm_creation(self, app_label: str, triggers: set[TriggerInfo]) -> bool:
        """Prompt for user confirmation in interactive mode.

        Args:
            app_label: Label of the Django app
            triggers: Set of TriggerInfo objects for triggers to remove

        Returns:
            True if the user answered "y" or "yes" (case-insensitive)
        """
        trigger_list = "\n  ".join(sorted(t.trigger_name for t in triggers))
        answer = input(f"\nWill remove the following triggers from {app_label}:\n  {trigger_list}\nProceed? (y/n): ")
        return answer.strip().lower() in {"y", "yes"}

    def _generate_migration_content(self, app_label: str, triggers: set[TriggerInfo]) -> str:
        """Generate the content for the migration file.

        The output is deterministic for a given set of triggers: dependencies and
        operations are de-duplicated and sorted.

        Args:
            app_label: Label of the Django app
            triggers: Set of TriggerInfo objects for triggers to remove

        Returns:
            String containing the migration file content
        """
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M")

        dependency_names: list[str] = []
        last_migration_name = self._get_last_migration_name(app_label)
        if last_migration_name:
            # Never emit ('app', 'None') when nothing has been applied yet.
            dependency_names.append(last_migration_name)
        # A trigger can only be dropped after the migration that created it has run.
        dependency_names.extend(sorted({t.migration_file.stem for t in triggers}))
        dependency_names = list(dict.fromkeys(dependency_names))

        lines = [
            f"# Generated by django-tenant-options on {timestamp}",
            "",
            "from django.db import migrations",
            "",
            "",
            "class Migration(migrations.Migration):",
            '    """Removes triggers previously created by django-tenant-options."""',
            "",
            "    dependencies = [",
            *(f"        ({app_label!r}, {name!r})," for name in dependency_names),
            "    ]",
            "",
            "    operations = [",
        ]
        for trigger_name in sorted({t.trigger_name for t in triggers}):
            sql = f"DROP TRIGGER IF EXISTS {trigger_name};"
            lines += [
                "        migrations.RunSQL(",
                # repr() yields a valid Python literal even if the captured name
                # contains quotes or newlines.
                f"            sql={sql!r},",
                "            reverse_sql='',  # No reverse operation as this removes triggers",
                "        ),",
            ]
        lines.append("    ]")
        return "\n".join(lines) + "\n"