#!/usr/bin/python3
import json
import os
import subprocess
import sys
import tempfile

from textual.app import App
from textual.containers import Container, Horizontal, VerticalScroll
from textual.widgets import Tree, Footer, Static


class Grummage(App):
    """Textual TUI that runs Grype on an SBOM and lets the user browse the matches.

    The left pane holds a tree of vulnerability matches that can be regrouped
    by package name, package type, vulnerability ID or severity. The right
    pane shows details for the selected match and, on request, the output of
    ``grype explain`` for it.
    """

    # NOTE(review): apart from "quit", none of the action names below has a
    # matching ``action_*`` method in this class; the key presses are handled
    # in ``on_key``. The entries presumably exist so the Footer can show key
    # hints -- confirm against the Textual bindings documentation before
    # renaming or removing them.
    BINDINGS = [
        ("v", "load_tree_by_vulnerability", "by Vuln"),
        ("n", "load_tree_by_package_name", "by Pkg Name"),
        ("s", "load_tree_by_severity", "by Severity"),
        ("t", "load_tree_by_package_type", "by Type"),
        ("e", "explain_vulnerability", "Explain Vuln"),
        ("q", "quit", "Quit"),
    ]

    # Display order of the severity view; any other value is bucketed as "Unknown".
    _SEVERITY_ORDER = ("Critical", "High", "Medium", "Low", "Negligible", "Unknown")

    def __init__(self, sbom_file=None):
        """Create the app.

        Args:
            sbom_file: Path to the SBOM file to analyse. When falsy,
                ``load_sbom`` reads the SBOM JSON from stdin instead.
        """
        super().__init__()
        self.sbom_file = sbom_file
        self.vulnerability_report = None
        # Selection state is read by on_key ("e") and explain_vulnerability;
        # initialise it here so those reads are safe before any node is selected.
        self.selected_vuln_id = None
        self.selected_package_name = None
        self.selected_package_version = None
        self.detailed_text = None
        self.debug_log_file = open("debug_log.txt", "w", encoding="utf-8")

    def quit(self):
        """Exit the application."""
        self.exit()

    def debug_log(self, message):
        """Append ``message`` to the debug log file and flush immediately.

        Does nothing once the log file has been closed (see ``on_unmount``).
        """
        if self.debug_log_file.closed:
            return
        self.debug_log_file.write(message + "\n")
        self.debug_log_file.flush()  # Ensure immediate write

    async def on_mount(self):
        """Build the two-pane layout, mount it, then load and analyse the SBOM."""
        self.debug_log("on_mount: Starting application setup")

        # Initialize widgets for the tree view, details display, and status bar
        self.tree_view = Tree("Vulnerabilities")
        self.details_display = Static("Select a node for more details.")
        self.status_bar = Static("Status: Initializing...")

        # Create containers for tree view (left) and details (right) in a horizontal layout
        tree_container = Container(self.tree_view)
        details_container = VerticalScroll(self.details_display)
        # Set widths for a 35/65 side-by-side layout
        tree_container.styles.width = "35%"
        details_container.styles.width = "65%"
        tree_container.styles.height = "98%"
        details_container.styles.height = "98%"

        # Use Horizontal container for side-by-side layout
        main_layout = Horizontal(tree_container, details_container)
        main_layout.styles.height = "98%"

        # Mount the main layout and the status bar at the bottom
        await self.mount(main_layout)
        await self.mount(self.status_bar)
        await self.mount(Footer())
        self.debug_log("on_mount: Layout mounted")

        # Load the SBOM from file or stdin
        await self.load_sbom()

    async def load_sbom(self):
        """Load the SBOM from ``self.sbom_file`` or stdin and run Grype on it.

        On success the tree is populated grouped by package name. Any failure
        is reported in the status bar and the debug log.
        """
        if self.sbom_file:
            # Load SBOM from the provided file path
            self.debug_log(f"Loading SBOM from file: {self.sbom_file}")
            sbom_json = self.load_json(self.sbom_file)
            if sbom_json is None:
                # load_json already logged the cause; do not feed "null" to Grype.
                self.status_bar.update("Status: Failed to read SBOM file.")
                return
        else:
            # Read SBOM from stdin
            self.debug_log("Loading SBOM from stdin")
            try:
                sbom_json = json.load(sys.stdin)
            except json.JSONDecodeError as e:
                self.debug_log(f"Error reading SBOM from stdin: {e}")
                self.status_bar.update("Status: Failed to read SBOM from stdin.")
                return

        # Run Grype analysis on the loaded SBOM JSON
        self.vulnerability_report = self.call_grype(sbom_json)
        if self.vulnerability_report and "matches" in self.vulnerability_report:
            self.load_tree_by_package_name()
            self.status_bar.update("Status: Vulnerability data loaded. Press N, T, V, S to change views, or E to explain.")
            self.debug_log("Vulnerability data loaded into tree")
        else:
            self.status_bar.update("Status: No vulnerabilities found or unable to load data.")
            self.debug_log("No vulnerability data found")

    def load_json(self, file_path):
        """Load SBOM JSON from a file.

        Returns:
            The parsed JSON, or ``None`` if the file cannot be read or parsed
            (the error is written to the debug log).
        """
        try:
            with open(file_path, "r", encoding="utf-8") as file:
                return json.load(file)
        except (OSError, ValueError) as e:
            # ValueError covers both JSONDecodeError and UnicodeDecodeError.
            self.debug_log(f"Error loading SBOM JSON: {e}")
            return None

    def call_grype(self, sbom_json):
        """Call Grype with the SBOM JSON to generate a vulnerability report.

        The SBOM is written to a temporary ``.json`` file that is passed to
        ``grype <file> -o json``; the file is removed again before returning.

        Returns:
            The parsed Grype JSON report, or ``None`` if Grype could not be
            run, exited non-zero, or produced unparsable output.
        """
        temp_file_path = None
        try:
            # delete=False so the file can be closed before Grype opens it;
            # it is removed explicitly in the ``finally`` clause below.
            with tempfile.NamedTemporaryFile(
                delete=False, mode="w", suffix=".json", encoding="utf-8"
            ) as temp_file:
                temp_file_path = temp_file.name
                json.dump(sbom_json, temp_file)

            # Call Grype using the temporary JSON file path
            result = subprocess.run(
                ["grype", temp_file_path, "-o", "json"],
                capture_output=True,
                text=True
            )

            if result.returncode != 0:
                self.debug_log(f"Grype encountered an error: {result.stderr}")
                return None

            # Return the parsed JSON if no errors occurred
            return json.loads(result.stdout)

        except Exception as e:
            # UI boundary: a failed scan must not take the app down.
            self.debug_log(f"Error running Grype: {e}")
            return None
        finally:
            if temp_file_path is not None:
                try:
                    os.remove(temp_file_path)
                except OSError as e:
                    self.debug_log(f"Could not remove temporary SBOM file {temp_file_path}: {e}")

    def _grype_matches(self):
        """Return the report's ``matches`` list, or ``[]`` when no report is loaded."""
        report = self.vulnerability_report
        if not report:
            return []
        return report.get("matches") or []

    def _build_node_data(self, match, severity=None):
        """Build the detail dict stored on a tree node for one Grype match.

        Args:
            match: One entry of the report's ``matches`` list.
            severity: Severity to record; defaults to the match's own
                severity, or "Unknown" when the match has none.
        """
        vulnerability = match["vulnerability"]
        artifact = match["artifact"]
        if severity is None:
            severity = vulnerability.get("severity", "Unknown")
        # A missing/null "fix" or an empty version list is shown as "None".
        fix_versions = (vulnerability.get("fix") or {}).get("versions") or ["None"]
        return {
            "id": vulnerability["id"],
            "package_name": artifact["name"],
            "package_version": artifact["version"],
            "severity": severity,
            "fix_version": fix_versions,
            "related": match.get("relatedVulnerabilities") or [],
        }

    def load_tree_by_package_name(self):
        """Display vulnerabilities organized by package name."""
        self.tree_view.clear()
        matches_by_package = {}
        for match in self._grype_matches():
            matches_by_package.setdefault(match["artifact"]["name"], []).append(match)

        for package_name, matches in matches_by_package.items():
            package_node = self.tree_view.root.add(package_name)
            for match in matches:
                # Store detailed info for right-hand pane display
                data = self._build_node_data(match)
                vuln_node = package_node.add(f"{data['id']}")
                vuln_node.data = data

    def load_tree_by_type(self):
        """Display vulnerabilities organized by package type, with package names under each type."""
        self.tree_view.clear()

        # Organize matches by package type and then by package name
        matches_by_type = {}
        for match in self._grype_matches():
            artifact = match["artifact"]
            packages = matches_by_type.setdefault(artifact["type"], {})
            packages.setdefault(artifact["name"], []).append(match)

        # Build the tree: type -> package name -> vulnerability ID
        for pkg_type, packages in matches_by_type.items():
            type_node = self.tree_view.root.add(pkg_type)
            for pkg_name, matches in packages.items():
                package_node = type_node.add(pkg_name)
                for match in matches:
                    data = self._build_node_data(match)
                    vuln_node = package_node.add(f"{data['id']}")
                    vuln_node.data = data

    def load_tree_by_vulnerability(self):
        """Display vulnerabilities organized by vulnerability ID."""
        self.tree_view.clear()
        matches_by_vuln = {}
        for match in self._grype_matches():
            matches_by_vuln.setdefault(match["vulnerability"]["id"], []).append(match)

        for vuln_id, matches in matches_by_vuln.items():
            vuln_node = self.tree_view.root.add(vuln_id)
            for match in matches:
                # Leaves are the affected packages; details go on the leaf.
                data = self._build_node_data(match)
                package_node = vuln_node.add(f"{data['package_name']}")
                package_node.data = data

    def load_tree_by_severity(self):
        """Display vulnerabilities organized by severity, in fixed order."""
        self.tree_view.clear()

        # Bucket matches by severity; unrecognised severities go to "Unknown".
        matches_by_severity = {severity: [] for severity in self._SEVERITY_ORDER}
        for match in self._grype_matches():
            severity = match["vulnerability"].get("severity", "Unknown")
            if severity not in matches_by_severity:
                severity = "Unknown"
            matches_by_severity[severity].append(match)

        # Add nodes in the fixed order, skipping empty buckets
        for severity in self._SEVERITY_ORDER:
            matches = matches_by_severity[severity]
            if not matches:
                continue
            severity_node = self.tree_view.root.add(severity)
            for match in matches:
                # Record the bucket name so the details pane agrees with the tree.
                data = self._build_node_data(match, severity=severity)
                vuln_node = severity_node.add(f"{data['id']} ({data['package_name']})")
                vuln_node.data = data

    async def on_key(self, event):
        """Handle key press events: n/t/v/s switch views, e explains the selection."""
        key = event.key.lower()
        views = {
            "n": (self.load_tree_by_package_name, "Status: Viewing by package name."),
            "t": (self.load_tree_by_type, "Status: Viewing by package type."),
            "v": (self.load_tree_by_vulnerability, "Status: Viewing by vulnerability ID."),
            "s": (self.load_tree_by_severity, "Status: Viewing by severity."),
        }
        if key in views:
            if not self.vulnerability_report:
                # Nothing to regroup when the scan failed or has not finished.
                self.status_bar.update("Status: No vulnerability data loaded.")
                return
            load_view, status = views[key]
            load_view()
            self.status_bar.update(status)
        elif key == "e" and self.selected_vuln_id and self.detailed_text:
            self.status_bar.update(f"Status: Explaining {self.selected_vuln_id} in {self.selected_package_name} ({self.selected_package_version})")
            await self.explain_vulnerability(self.selected_vuln_id)

    async def on_tree_node_selected(self, event):
        """Show detailed information for the selected vulnerability.

        Nodes without attached data (grouping nodes) clear the current
        selection instead.
        """
        details = event.node.data
        if not details:
            self.details_display.update("No data found for selected node.")
            # Clear the whole selection so "e" cannot act on stale details.
            self.selected_vuln_id = None
            self.selected_package_name = None
            self.selected_package_version = None
            self.detailed_text = None
            self.debug_log("No data found for selected node.")
            return

        # Track the selection for a later "explain" request
        self.selected_vuln_id = details["id"]
        self.selected_package_name = details["package_name"]
        self.selected_package_version = details["package_version"]
        self.status_bar.update(f"Status: Selected {details['id']} in {details['package_name']} ({details['package_version']})")

        lines = [
            f"ID: {details['id']}",
            f"Package: {details['package_name']} ({details['package_version']})",
            f"Severity: {details['severity']}",
            f"Fix Version: {', '.join(details['fix_version'])}",
            "Related Vulnerabilities:",
        ]
        lines.extend(
            f"- {related.get('id', 'Unknown')} ({related.get('dataSource', 'Unknown')})"
            for related in details["related"]
        )
        detail_text = "\n".join(lines) + "\n"

        self.debug_log(f"Displaying details for {details['id']}")
        self.details_display.update(detail_text)
        self.detailed_text = detail_text

    def on_unmount(self):
        """Close the log file when the application exits."""
        self.debug_log_file.close()

    async def explain_vulnerability(self, vuln_id):
        """Run ``grype explain --id <vuln_id>`` and show its output in the details pane.

        The Grype JSON report is piped to the explain command on stdin. The
        report already held in ``self.vulnerability_report`` is reused when
        available; otherwise Grype is run on ``self.sbom_file`` first.
        """
        try:
            if self.vulnerability_report:
                # Reuse the cached report: avoids a second full scan and also
                # works when the SBOM came from stdin (no file path to re-scan).
                report_json = json.dumps(self.vulnerability_report)
            elif self.sbom_file:
                analyze_result = subprocess.run(
                    ["grype", self.sbom_file, "-o", "json"],
                    capture_output=True,
                    text=True
                )

                # Check if the SBOM analysis was successful
                if analyze_result.returncode != 0:
                    self.details_display.update(f"Error analyzing SBOM: {analyze_result.stderr}")
                    self.debug_log(f"Error analyzing SBOM for explanation: {analyze_result.stderr}")
                    return
                report_json = analyze_result.stdout
            else:
                self.details_display.update("Error analyzing SBOM: no vulnerability report or SBOM file available.")
                self.debug_log("Error analyzing SBOM for explanation: no report or SBOM file available")
                return

            # Run Grype's explain command with the specific vulnerability ID
            explain_result = subprocess.run(
                ["grype", "explain", "--id", vuln_id],
                input=report_json,  # The Grype JSON report is read from stdin
                capture_output=True,
                text=True
            )

            # Check and display the result in the details pane
            if explain_result.returncode == 0:
                explanation = explain_result.stdout
                details = self.detailed_text or ""
                self.details_display.update(f"{details}\nExplanation for {vuln_id}:\n\n{explanation}")
                self.debug_log(f"Displaying explanation for {vuln_id}")
            else:
                self.details_display.update(f"Failed to explain {vuln_id}.\nError: {explain_result.stderr}")
                self.debug_log(f"Error explaining {vuln_id}: {explain_result.stderr}")

        except Exception as e:
            # UI boundary: report the failure in the pane instead of crashing.
            self.details_display.update(f"Error explaining {vuln_id}: {e}")
            self.debug_log(f"Exception in explain_vulnerability: {e}")


if __name__ == "__main__":
    # An SBOM path is required; without one, print a usage message to stderr
    # and exit with a non-zero status instead of exiting silently.
    if len(sys.argv) < 2:
        sys.exit(f"Usage: {sys.argv[0]} <sbom-file>")
    sbom_file = sys.argv[1]
    Grummage(sbom_file).run()
