# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import contextlib
import multiprocessing
import queue  # for Empty
import sys
import traceback
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from io import StringIO
from typing import Any

PROCESS_KILLED = -9
PROCESS_NO_RESULT = -999


# Similar to https://docs.python.org/3/library/subprocess.html#subprocess.CompletedProcess
# (args, check_returncode() are intentionally not supported here.)
@dataclass
class CompletedProcess:
    returncode: int
    stdout: str
    stderr: str


class ChildProcessWrapper:
    def __init__(
        self,
        result_queue: Any,
        target: Callable[..., None],
        args: Sequence[Any] | None,
        kwargs: dict[str, Any] | None,
    ) -> None:
        self.target = target
        self.args = () if args is None else args
        self.kwargs = {} if kwargs is None else kwargs
        self.result_queue = result_queue

    def __call__(self) -> None:
        # Capture stdout/stderr
        old_stdout = sys.stdout
        old_stderr = sys.stderr
        sys.stdout = StringIO()
        sys.stderr = StringIO()

        try:
            self.target(*self.args, **self.kwargs)
            returncode = 0
        except SystemExit as e:  # Handle sys.exit()
            returncode = e.code if isinstance(e.code, int) else 0
        except BaseException:
            traceback.print_exc()
            returncode = 1
        finally:
            # Collect outputs and restore streams
            stdout = sys.stdout.getvalue()
            stderr = sys.stderr.getvalue()
            sys.stdout = old_stdout
            sys.stderr = old_stderr
            with contextlib.suppress(Exception):
                self.result_queue.put((returncode, stdout, stderr))


def run_in_spawned_child_process(
    target: Callable[..., None],
    *,
    args: Sequence[Any] | None = None,
    kwargs: dict[str, Any] | None = None,
    timeout: float | None = None,
    rethrow: bool = False,
) -> CompletedProcess:
    """Run `target` in a spawned child process, capturing stdout/stderr.

    The provided `target` must be defined at the top level of a module, and must
    be importable in the spawned child process. Lambdas, closures, or interactively
    defined functions (e.g., in Jupyter notebooks) will not work.

    If `rethrow=True` and the child process exits with a nonzero code,
    raises ChildProcessError with the captured stderr.
    """
    ctx = multiprocessing.get_context("spawn")
    result_queue = ctx.Queue()
    process = ctx.Process(target=ChildProcessWrapper(result_queue, target, args, kwargs))
    process.start()

    try:
        process.join(timeout)
        if process.is_alive():
            process.terminate()
            process.join()
            result = CompletedProcess(
                returncode=PROCESS_KILLED,
                stdout="",
                stderr=f"Process timed out after {timeout} seconds and was terminated.",
            )
        else:
            try:
                returncode, stdout, stderr = result_queue.get(timeout=1.0)
            except (queue.Empty, EOFError):
                result = CompletedProcess(
                    returncode=PROCESS_NO_RESULT,
                    stdout="",
                    stderr="Process exited or crashed before returning results.",
                )
            else:
                result = CompletedProcess(
                    returncode=returncode,
                    stdout=stdout,
                    stderr=stderr,
                )

        if rethrow and result.returncode != 0:
            raise ChildProcessError(
                f"Child process exited with code {result.returncode}.\n"
                "--- stderr-from-child-process ---\n"
                f"{result.stderr}"
                "<end-of-stderr-from-child-process>\n"
            )

        return result

    finally:
        try:
            result_queue.close()
            result_queue.join_thread()
        except Exception:  # noqa: S110
            pass
        if process.is_alive():
            process.kill()
            process.join()
