#!/usr/bin/env python3
"""
Chroot into a block device or directory.

Usage:
    chroot-partition <device_or_path> [command] [-s] [-p passphrase]

Options:
    -s, --stay          Stay in chroot after the command (default: stay only when no command is given)
    -p, --passphrase    Passphrase for LUKS-encrypted partitions

Author: Arjen Balfoort
Date: 2026-03-13
Dependencies: cryptsetup
"""

import argparse
import os
import shlex
import shutil
import stat
import subprocess
import sys
import tempfile
from typing import Optional

def is_valid_chroot_target(target):
    """
    Check if the target is a valid chroot environment.

    Paths under /dev/ must exist and be block devices (symlinks such as
    /dev/mapper/* or /dev/disk/by-uuid/* are followed).  Whether the device
    holds a usable root filesystem is left to the later mount.  Any other
    path must contain bin, lib, etc and bin/sh.

    Args:
        target (str): Block device or directory to check.

    Returns:
        bool: True if the target passes the checks above.
    """
    if target.startswith("/dev/"):
        try:
            mode = os.stat(target).st_mode
        except OSError:
            return False
        # Existence alone would accept character devices such as /dev/null.
        return stat.S_ISBLK(mode)

    # lexists instead of exists: entries such as bin/sh are often absolute
    # symlinks that only resolve correctly from inside the chroot, not
    # against the host's root.
    essential_paths = ("bin", "lib", "etc", "bin/sh")
    return all(os.path.lexists(os.path.join(target, rel))
               for rel in essential_paths)

def get_mapped_path(device: str, mapper_dir: str = "/dev/mapper") -> Optional[str]:
    """
    Check if a LUKS device is already mapped.

    Runs ``cryptsetup status`` on every entry of *mapper_dir* and compares
    the "device:" line of its output with *device*.  Both sides go through
    os.path.realpath() so that, for example, a /dev/disk/by-uuid symlink and
    the path cryptsetup prints compare equal when they name the same node.

    Args:
        device (str): Device path to check.
        mapper_dir (str, optional): Directory holding the device-mapper
            nodes. Defaults to "/dev/mapper".

    Returns:
        str: Path to the mapped device, or None if not found (also when
        *mapper_dir* cannot be listed).
    """
    try:
        entries = sorted(os.listdir(mapper_dir))
    except OSError:
        # No device-mapper directory: nothing can be mapped.
        return None

    wanted = os.path.realpath(device)
    for entry in entries:
        # "control" is the device-mapper control node, not a mapping.
        if entry == "control":
            continue
        mapped_path = os.path.join(mapper_dir, entry)
        result = subprocess.run(
            ["cryptsetup", "status", mapped_path],
            capture_output=True,
            text=True,
            check=False
        )
        if result.returncode != 0:
            continue
        # Compare the exact backing device. A plain substring test on the
        # whole output would let /dev/sda1 match a mapping of /dev/sda10.
        for line in result.stdout.splitlines():
            key, sep, value = line.strip().partition(":")
            if sep and key == "device":
                if os.path.realpath(value.strip()) == wanted:
                    return mapped_path
                break
    return None

def main():
    """
    Command-line entry point.

    Parses the arguments, checks the target with is_valid_chroot_target(),
    re-runs the script through pkexec when not running as root, and
    otherwise calls chroot_root_device().  Exits with status 1 when the
    target is invalid or the chroot fails.  Under pkexec, it exits with the
    status of the elevated run (or pkexec's own status on failure).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("device_or_path",
                        help="Device or path to chroot into")
    # nargs="?" makes the positional argument optional.
    parser.add_argument("command",
                        nargs="?",
                        default=None,
                        help="Command to run in chroot. When given: stay = False, else stay = True")
    # default=None (instead of store_true's usual False) distinguishes
    # "-s not given" from "-s given", so the default can depend on command.
    parser.add_argument("-s", "--stay",
                        action="store_true",
                        default=None,
                        help="Stay in chroot after the command "
                             "(default: stay only when no command is given)")
    parser.add_argument("-p", "--passphrase",
                        default=None,
                        help="Passphrase for LUKS-encrypted partitions")

    args = parser.parse_args()

    # Check if the target is a valid chroot environment
    if not is_valid_chroot_target(args.device_or_path):
        print(f"Error: {args.device_or_path} is not a valid chroot target.",
              file=sys.stderr)
        print("For block devices, ensure it's a valid filesystem.",
              file=sys.stderr)
        print("For directories, ensure it contains /bin, /lib, /etc, and /bin/sh.",
              file=sys.stderr)
        sys.exit(1)

    # Set default for stay based on command
    if args.stay is None:
        args.stay = args.command is None  # stay=True if no command, else stay=False

    if os.geteuid() != 0:
        # Re-execute as root with pkexec. pkexec does not start the program
        # in the caller's working directory, so pass absolute paths. The
        # argv is rebuilt from the parsed (already resolved) options.
        argv = ["pkexec", sys.executable, os.path.abspath(sys.argv[0])]
        if args.stay:
            argv.append("--stay")
        if args.passphrase is not None:
            # "=" form so a passphrase starting with "-" is not read as an option
            argv.append(f"--passphrase={args.passphrase}")
        # "--" keeps a path or command that starts with "-" positional
        argv += ["--", os.path.abspath(args.device_or_path)]
        if args.command is not None:
            argv.append(args.command)
        # check=False: pass the elevated run's (or pkexec's) exit status on
        # instead of raising a traceback, e.g. when authentication is dismissed.
        sys.exit(subprocess.run(argv, check=False).returncode)

    # Now running as root
    ok = chroot_root_device(args.device_or_path,
                            args.command,
                            stay_in_chroot=args.stay,
                            passphrase=args.passphrase)
    sys.exit(0 if ok else 1)

def _unlock_if_luks(device: str, passphrase: Optional[str]):
    """
    Open *device* with cryptsetup when blkid reports it as crypto_LUKS.

    Returns:
        tuple: (path to mount, mapper name opened by this call or None).
        The name is None when the device is not LUKS or was already mapped,
        so the caller only closes mappings it created itself.

    Raises:
        subprocess.CalledProcessError: if ``cryptsetup open`` fails.
    """
    # check=False: blkid exits non-zero when it cannot identify the device.
    # Such a device is not LUKS, so let the later mount report the problem.
    fs_type = subprocess.run(
        ["blkid", "-s", "TYPE", "-o", "value", device],
        capture_output=True,
        text=True,
        check=False
    ).stdout.strip()
    if fs_type != "crypto_LUKS":
        return device, None

    # Check if the partition is already mapped (under any name)
    mapper_path = get_mapped_path(device)
    if mapper_path:
        return mapper_path, None

    # Map the LUKS partition to a predictable name
    name = f"luks-{os.path.basename(device)}"
    cmd = ["cryptsetup", "open", "--type", "luks", device, name]
    if passphrase:
        subprocess.run(cmd, input=passphrase, text=True, check=True)
    else:
        # Prompt for passphrase interactively
        subprocess.run(cmd, check=True)
    return f"/dev/mapper/{name}", name


def _mount_api_filesystems(target_dir: str, mounted: list) -> None:
    """
    Mount proc, sysfs, devtmpfs, devpts and tmpfs filesystems under *target_dir*.

    efivarfs is also mounted when sys/firmware/efi/efivars exists in the
    freshly mounted sysfs.  Each mount point is appended to *mounted* right
    after it was mounted, so the caller can unmount exactly what succeeded
    even when a later mount fails.

    Raises:
        subprocess.CalledProcessError: if a mount command fails.
    """
    mounts = (
        # (fstype, source, mount point relative to target_dir)
        ("proc", "/proc", "proc"),
        ("sysfs", "/sys", "sys"),
        # /dev is not bind-mounted as well: a devtmpfs mounted on the same
        # point would simply hide the bind mount.
        ("devtmpfs", "devtmpfs", "dev"),
        ("devpts", "devpts", "dev/pts"),
        ("tmpfs", "tmpfs", "dev/shm"),
        ("tmpfs", "tmpfs", "run"),
    )
    for fstype, source, rel in mounts:
        mount_point = os.path.join(target_dir, rel)
        subprocess.run(["mount", "-t", fstype, source, mount_point], check=True)
        mounted.append(mount_point)

    efivars = os.path.join(target_dir, "sys/firmware/efi/efivars")
    if os.path.exists(efivars):
        subprocess.run(["mount", "-t", "efivarfs", "efivarfs", efivars], check=True)
        mounted.append(efivars)


def _install_host_file(target_dir: str, rel_path: str, swapped: list) -> None:
    """
    Copy the host's /<rel_path> to the same place inside *target_dir*.

    An existing entry in the target is first renamed to "<name>.orig".  The
    pair (installed path, backup path or None) is appended to *swapped* as
    soon as the original has been moved aside, so _restore_host_files() can
    also undo a partially completed swap.  Does nothing when the host file
    does not exist.
    """
    host_path = os.path.join("/", rel_path)
    if not os.path.exists(host_path):
        return
    installed = os.path.join(target_dir, rel_path)
    backup = None
    # lexists, not exists: inside the target these may be symlinks that do
    # not resolve from the host. os.rename moves the link itself, whereas
    # shutil.copy would write through it.
    if os.path.lexists(installed):
        backup = installed + ".orig"
        os.rename(installed, backup)
    swapped.append((installed, backup))
    shutil.copy(host_path, installed)


def _restore_host_files(swapped: list) -> None:
    """Undo _install_host_file(): remove the copies and put the originals back."""
    for installed, backup in reversed(swapped):
        try:
            if os.path.lexists(installed):
                os.remove(installed)
            if backup is not None:
                os.rename(backup, installed)
        except OSError as e:
            print(f"Failed to restore {installed}: {e}")


def _write_setup_script(script_path: str, target_dir: str,
                        command: Optional[str], stay_in_chroot: bool) -> None:
    """
    Write the bash script that is run as /setup-chroot.sh inside the chroot.

    The script mounts the /etc/fstab entries, runs *command* (exiting with
    status 1 if it fails) and, when *stay_in_chroot* is set, then starts an
    interactive /bin/bash.
    """
    setup_script_content = """#!/bin/bash
# Mount all filesystems listed in /etc/fstab
if [ -e /etc/fstab ]; then
    # Run mount -a twice:
    # - first to make sure /boot is mounted before /boot/efi
    # - second to make sure /boot/efi is mounted
    mount -a 2>/dev/null
    mount -a
fi
"""
    # Add the command if provided; a failure becomes exit status 1
    if command:
        setup_script_content += f"\n{command} || exit 1\n"

    if stay_in_chroot:
        # Start the interactive shell from this script so it runs in the
        # same private mount namespace as the fstab mounts above. A separate
        # `unshare -m` would start from a new namespace without them.
        # "exit 0" keeps the shell's last command status from being reported
        # as a failure of the setup.
        banner = (f"\nEntering chroot environment at {target_dir}. "
                  "Exit with 'exit' or Ctrl+D.")
        setup_script_content += f"\necho {shlex.quote(banner)}\n/bin/bash\nexit 0\n"

    with open(script_path, "w", encoding="utf-8") as f:
        f.write(setup_script_content)
    os.chmod(script_path, 0o755)


def chroot_root_device(path_or_device: str,
                       command: Optional[str] = None,
                       stay_in_chroot: bool = True,
                       passphrase: Optional[str] = None) -> bool:
    """
    Chroot into the root partition of the specified device or path.

    For a path under /dev/, the device is first unlocked with cryptsetup
    (when blkid reports crypto_LUKS) and then mounted on a new temporary
    directory.  Any other path is used as the chroot directory as-is.
    proc, sysfs, devtmpfs, devpts and tmpfs are mounted inside the target,
    and the host's /etc/resolv.conf and /etc/hosts are copied in.  A
    generated /setup-chroot.sh is then run in the chroot under
    ``unshare -m``.  Everything set up here is undone before returning,
    including on failure.

    Args:
        path_or_device (str): Device or path to chroot into
        command (str, optional): Shell command to run in chroot after the
            /etc/fstab mounts. Defaults to None.
        stay_in_chroot (bool, optional): Open an interactive /bin/bash in
            the chroot afterwards. Defaults to True.
        passphrase (str, optional): Passphrase for LUKS-encrypted partitions.
            When not given, cryptsetup prompts for it. Defaults to None.

    Returns:
        bool: True if the setup script (and command) succeeded, else False.
    """
    target_dir = path_or_device
    tmp_dir = None          # temporary mount point created for a device
    mounted_root = False    # True once the device is mounted on tmp_dir
    luks_name = None        # mapper name opened by this call, if any
    api_mounts = []         # mount points mounted under target_dir, in order
    swapped_files = []      # (installed copy, backup or None) per host file
    script_path = None

    try:
        if path_or_device.startswith("/dev/"):
            device, luks_name = _unlock_if_luks(path_or_device, passphrase)

            # Mount the (possibly decrypted) partition
            tmp_dir = tempfile.mkdtemp()
            target_dir = tmp_dir
            try:
                subprocess.run(["mount", device, target_dir], check=True)
            except subprocess.CalledProcessError as e:
                print(f"Failed to mount {device}: {e}")
                return False
            mounted_root = True

        # Create the mount points used below plus var/lock. makedirs with
        # exist_ok does not write anything when they already exist.
        try:
            for rel in ("proc", "sys", "dev/shm", "run", "var/lock"):
                os.makedirs(os.path.join(target_dir, rel), exist_ok=True)
        except OSError as e:
            print(f"{target_dir} is not writable: {e}")
            return False

        _mount_api_filesystems(target_dir, api_mounts)

        # Give the chroot the host's name resolution
        for rel in ("etc/resolv.conf", "etc/hosts"):
            _install_host_file(target_dir, rel, swapped_files)

        script_path = os.path.join(target_dir, "setup-chroot.sh")
        _write_setup_script(script_path, target_dir, command, stay_in_chroot)

        # unshare -m gives the script its own mount namespace, so the fstab
        # mounts it makes are dropped with it instead of staying on the host.
        # The argv is passed directly (no `sh -c`), so paths need no quoting.
        subprocess.run(["unshare", "-m", "chroot", target_dir, "/setup-chroot.sh"],
                       check=True)
        return True

    except subprocess.CalledProcessError as e:
        print(f"Error: {e}")
        return False
    except Exception as e:
        print(f"Unexpected error: {e}")
        return False
    finally:
        # Restore the target's own resolv.conf / hosts (target still mounted)
        _restore_host_files(swapped_files)

        # Cleanup: remove script
        if script_path is not None and os.path.lexists(script_path):
            try:
                os.remove(script_path)
            except OSError:
                pass

        # Unmount in reverse order so nested mounts (dev/pts) go before dev.
        # This also covers directory targets, not only temporary mounts.
        for mount_point in reversed(api_mounts):
            subprocess.run(["umount", mount_point], check=False, capture_output=True)

        # Unmount the device and remove the temp dir if one was created
        if tmp_dir is not None:
            if mounted_root:
                subprocess.run(["umount", "-R", tmp_dir], check=False, capture_output=True)
            try:
                os.rmdir(tmp_dir)
            except OSError:
                pass

        # Unmap the LUKS partition (only if this call opened it)
        if luks_name is not None:
            subprocess.run(["cryptsetup", "close", luks_name], check=False)

if __name__ == "__main__":
    # Runs only when executed as a script (which is also how the pkexec
    # re-execution in main() starts it); importing the module starts nothing.
    main()
