#! /usr/bin/env python3
#
# Run semgrep on a series of pairs (rules, repo) with different options,
# and report the time it takes. Optionally upload the results to the semgrep
# dashboard.
# This script assumes the presence of 'semgrep' in the PATH.
#
import argparse
import copy
import json
import logging
import os
import re
import subprocess
import sys
import time
import urllib.request
from contextlib import contextmanager
from pathlib import Path
from typing import Iterator
from typing import List
from typing import Optional
from typing import Tuple

import requests
from config import BenchmarkRunSetupData
from config import SemgrepBenchmarkConfig
from constants import BPS_ENDPOINT
from constants import DASHBOARD_URL
from constants import LPS_ENDPOINT
from constants import PREP_FILE_TEMPLATE
from constants import RULE_CONFIG_CACHE_DIR
from constants import STATS_URL
from constants import STD
from corpus import Corpus
from corpus import CORPUS_LOCATION
from corpus import DUMMY_CORPUSES
from corpus import INTERNAL_CORPUSES
from corpus import LARGE_CORPUSES
from corpus import MEDIUM_CORPUSES
from corpus import SMALL_CORPUSES
from repository_time_per_rule import RepositoryTimePerRule
from variant import SEMGREP_VARIANTS
from variant import SemgrepVariant
from variant import STD_VARIANTS

# Module-level logger, named after this script's file path. It accepts
# records at INFO level and above.
logger = logging.getLogger(__file__)
logger.setLevel(logging.INFO)
# The handler below writes to stderr, which keeps log lines separate from the
# benchmark report that the rest of the script prints to stdout.
handler = logging.StreamHandler(stream=sys.stderr)
# Each record shows a timestamp, the logger name, the level and the message.
handler.setFormatter(
    logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
)
logger.addHandler(handler)

# Markers used to locate the benchmark directory inside, respectively, the
# file paths and the dotted rule ids (check_id) found in semgrep results.
PERF_PATH = "OSS/perf"
PERF_RULE_PATH = "OSS.perf"


# This class allows us to put semgrep results in a set and compute set
# differences while saving the original JSON dictionary
class SemgrepResult:
    """Hashable, orderable wrapper around one semgrep finding.

    Attributes:
        res: the original finding dictionary, left untouched.
        str: a canonical JSON rendering of the finding in which run-specific
            details are masked. Equality, ordering and hashing are all
            defined on this string.
    """

    def __init__(self, dict: dict) -> None:
        self.res = dict
        # Work on a copy so that masking never alters the caller's data.
        dict2 = copy.deepcopy(dict)

        # We use self.str to compare dicts, so change the below
        # to abstract away differences

        # The benchmarks are primarily focused on ensuring
        # the matches don't change. Previously, we masked
        # the extra fields individually when necessary, but
        # as they change frequently and often contain info
        # that varies by run, it's easier to just mask them
        if "extra" in dict2:
            dict2["extra"] = "<masked in benchmarks>"

        if "check_id" in dict2:
            relative_path_start = dict2["check_id"].find(PERF_RULE_PATH)
            # str.find returns -1 when the marker is absent. Slicing from -1
            # would keep only the last character and make unrelated rule ids
            # compare equal, so the id is left as-is in that case.
            if relative_path_start >= 0:
                dict2["check_id"] = dict2["check_id"][relative_path_start:]

        json_str = json.dumps(dict2, sort_keys=True)

        # Abbreviate paths to avoid file-system differences
        # Do this here because path appears multiple times
        # NOTE: the prefix pattern only covers letters, '/' and '-'; a prefix
        # containing other characters is not masked.
        self.str = re.sub(
            rf'"path": "[A-Za-z/\-]*{PERF_PATH}',
            f'"path": "<masked in benchmarks>/{PERF_PATH}',
            json_str,
        )

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, SemgrepResult):
            # Let Python try the reflected operation and fall back to its
            # default instead of raising on mixed-type comparisons.
            return NotImplemented

        return self.str == other.str

    def __lt__(self, other: object) -> bool:
        if not isinstance(other, SemgrepResult):
            return NotImplemented

        # Ascending order of the canonical string, so sorted() is ascending.
        return self.str < other.str

    def __hash__(self) -> int:
        # Must agree with __eq__, hence based on the same canonical string.
        return hash(self.str)


# Add support for: with chdir(DIR): ...
@contextmanager
def chdir(dir: str) -> Iterator[None]:
    """Temporarily make ``dir`` the process working directory.

    The directory that was current on entry is restored when the ``with``
    block ends, whether it ends normally or through an exception.
    """
    previous_dir = os.getcwd()
    os.chdir(dir)
    try:
        yield
    finally:
        # Always go back, even if the body of the `with` raised.
        os.chdir(previous_dir)


def upload(url: str, value: str) -> None:
    """Send ``value`` as the request body to ``url`` and print the reply.

    Args:
        url: address to send the data to.
        value: payload; it is encoded as ASCII, so non-ASCII text raises
            UnicodeEncodeError.

    Raises:
        Whatever ``urllib.request.urlopen`` raises on failure.
    """
    logger.info(f"Uploading to {url}")
    # Use the response as a context manager so that it is closed as soon as
    # its body has been read, instead of being left open.
    with urllib.request.urlopen(  # nosem
        url=url,
        data=value.encode("ascii"),
    ) as r:
        print(r.read().decode())


def _count_target_lines(targets: List[dict]) -> int:
    """Return the total number of lines of the target files that can be read.

    Files that no longer exist or that cannot be decoded as text contribute
    nothing to the total.
    """
    num_lines = 0
    for target in targets:
        assert "path" in target
        try:
            # Iterating over the file object yields one item per line. The
            # `with` closes the file even when decoding fails part-way.
            with open(target["path"]) as f:
                num_lines += sum(1 for _ in f)
        except UnicodeDecodeError:
            # Not a text file in the default encoding: skip it entirely.
            pass
        except FileNotFoundError:
            # Semgrep creates artificial temporary target files (for
            # example, the Flask lifecycle harness) to aid in analysis under
            # some circumstances. It cleans these up after it runs. If we
            # try to open them here, we'll get an error.
            pass
    return num_lines


# coupling: timings has type Semgrep_output_v1.profile
def upload_result(
    variant_name: str, metric_name: str, value: float, timings: Optional[dict]
) -> None:
    """Upload one benchmark measurement to the dashboard.

    The overall duration is always uploaded. For the STD variant, when timing
    data is available, bytes-per-second and lines-per-second figures derived
    from it are uploaded too, followed by the raw timing data itself.

    Args:
        variant_name: name of the semgrep variant that was run.
        metric_name: dashboard metric under which ``value`` is stored.
        value: overall runtime of the benchmark.
        timings: semgrep's profiling output, or None when it was not
            collected.
    """
    # Upload overall runtime of the benchmark
    metric_url = f"{DASHBOARD_URL}/api/metric/{metric_name}"
    upload(metric_url, str(value))

    if variant_name == STD and timings is not None:
        # Compute bps from semgrep timing data
        assert "rules" in timings
        assert "profiling_times" in timings
        assert "total_bytes" in timings
        prof_times = timings["profiling_times"]
        assert "total_time" in prof_times
        num_rules = len(timings["rules"])
        bps = timings["total_bytes"] / (prof_times["total_time"] * num_rules)

        bps_url = f"{DASHBOARD_URL}/api/metric/{BPS_ENDPOINT}"
        upload(bps_url, str(bps))

        # Similarly, compute lps
        assert "targets" in timings
        num_lines = _count_target_lines(timings["targets"])
        lps = num_lines / (prof_times["total_time"] * num_rules)

        lps_url = f"{DASHBOARD_URL}/api/metric/{LPS_ENDPOINT}"
        upload(lps_url, str(lps))

        # Upload timing data as a json
        logger.info(f"Uploading timing data to {STATS_URL}")
        headers = {"content-type": "application/json"}
        r = requests.post(
            STATS_URL, data=json.dumps(timings), headers=headers, timeout=30
        )
        print(r.content)


def standardize_findings(findings: dict, include_time: bool) -> Tuple[dict, dict]:
    """Validate semgrep's JSON output and convert it to a comparable form.

    Args:
        findings: the parsed JSON output of a semgrep run.
        include_time: whether the 'time' section is expected and returned.

    Returns:
        A pair ``(results, timings)``. ``results`` holds the original
        'errors' value and the findings as a set of SemgrepResult.
        ``timings`` is the 'time' section, or None when ``include_time`` is
        false.

    Raises:
        Exception: when 'errors', 'results' or (if requested) 'time' is
            missing. The message starts with a dump of ``findings``.
    """
    expected_keys = ["errors", "results"]
    if include_time:
        expected_keys.append("time")

    # Keys are checked in a fixed order so the first missing one is reported.
    for key in expected_keys:
        if key not in findings:
            raise Exception(
                json.dumps(findings, indent=4)
                + f"\n\nDid not find expected key '{key}'"
            )

    standardized = {
        "errors": findings["errors"],
        "results": {SemgrepResult(finding) for finding in findings["results"]},
    }
    if include_time:
        return standardized, findings["time"]
    return standardized, None


def jsonify_findings(findings: dict) -> dict:
    """
    Convert standardized findings into a dictionary that can be dumped as JSON.
    The output is later used by scripts/compare_perf to confirm correctness of
    benchmarks, so the results are emitted in sorted order to make the output
    independent of the order in which they were found.
    """
    errors = findings["errors"]
    decoded_results = []
    for finding in sorted(findings["results"]):
        # finding.str is the masked, canonical JSON text of the finding.
        decoded_results.append(json.loads(finding.str))
    return {"errors": errors, "results": decoded_results}


def output_differences(
    findings: set, std_findings: set, variant: str
) -> Tuple[int, int]:
    """Print the findings on which a variant and std disagree.

    Findings reported only by the variant are printed first, then those
    reported only by std; each group is printed in sorted order as indented
    JSON of the original finding.

    Returns:
        ``(extra, missing)``: how many findings only the variant has, and how
        many only std has.
    """
    only_in_variant = findings.difference(std_findings)
    only_in_std = std_findings.difference(findings)
    extra_count, missing_count = len(only_in_variant), len(only_in_std)

    print("In", variant, "but not std", extra_count, "findings :")
    for finding in sorted(only_in_variant):
        print(json.dumps(finding.res, sort_keys=True, indent=4))

    print("In std but not", variant, missing_count, "findings :")
    for finding in sorted(only_in_std):
        print(json.dumps(finding.res, sort_keys=True, indent=4))

    return extra_count, missing_count


def _build_semgrep_command(
    docker: str,
    corpus: Corpus,
    variant: SemgrepVariant,
    include_time: bool,
    trace: bool,
    trace_endpoint: Optional[str],
) -> List[str]:
    """Return the argv list used to run semgrep on ``corpus`` for ``variant``."""
    common_args = [
        "--strict",
        "--json",
        "--timeout",
        "0",
        "--metrics=off",
        "--no-git-ignore",  # because files in bench/*/input/ are git-ignored
    ]
    if docker:
        # Absolute paths are required by docker for mounting volumes, otherwise
        # they end up empty inside the container.
        args = [
            "docker",
            "run",
            "--rm",  # cleanup after ourselves
            "-v",
            os.path.abspath(corpus.rule_dir) + ":/rules",
            "-v",
            os.path.abspath(corpus.target_dir) + ":/targets",
            # some containers use /bin/bash as the entrypoint,
            # so we override it here
            "--entrypoint",
            "semgrep",
            docker,
            "--quiet",
            "--config",
            "/rules",
            "/targets",
        ]
    else:
        # Absolute paths for rules and targets are required by semgrep
        # when running within the semgrep docker container.
        args = [
            "semgrep",
            "--config",
            os.path.abspath(corpus.rule_dir),
        ]
    args.extend(common_args)
    args.extend(corpus.semgrep_options)
    if variant.semgrep_extra != "":
        args.append(variant.semgrep_extra)
    if include_time:
        args.append("--time")
    if trace:
        args.append("--trace")
    if trace_endpoint:
        # The flag and its value must be two separate argv elements: no shell
        # is involved, so a single "flag value" string would reach semgrep as
        # one unknown argument.
        args.extend(["--trace-endpoint", trace_endpoint])

    # Finally, add target
    if not docker:
        # We already did this for docker as /targets above.
        args.append(os.path.abspath(corpus.target_dir))
    return args


def run_semgrep(
    docker: str,
    corpus: Corpus,
    variant: SemgrepVariant,
    include_time: bool,
    hard_timeout: int = 0,
    trace: bool = False,
    trace_endpoint: Optional[str] = None,
) -> Tuple[float, dict]:
    """Run semgrep once on a corpus and return its duration and JSON output.

    Args:
        docker: docker image to run semgrep from; when empty, the 'semgrep'
            found in the PATH is used.
        corpus: provides the rule directory, target directory and extra
            semgrep options.
        variant: provides the extra semgrep option and the value exported as
            SEMGREP_CORE_EXTRA.
        include_time: pass --time to semgrep.
        hard_timeout: maximum number of seconds the command may run; 0 means
            no limit.
        trace: pass --trace to semgrep.
        trace_endpoint: when set, passed as the value of --trace-endpoint.

    Returns:
        ``(seconds, results)``: the wall-clock duration of the command and
        its stdout parsed as JSON.

    Raises:
        subprocess.TimeoutExpired: the command exceeded ``hard_timeout``.
        subprocess.CalledProcessError: semgrep exited with a status other
            than 0 or 3.
        SystemExit: semgrep's stdout could not be parsed as JSON.
    """
    args = _build_semgrep_command(
        docker, corpus, variant, include_time, trace, trace_endpoint
    )

    print(f"current directory: {os.getcwd()}")
    print("semgrep command: {}".format(" ".join(args)))
    os.environ["SEMGREP_CORE_EXTRA"] = variant.semgrep_core_extra
    print(f"extra arguments for semgrep-core: '{variant.semgrep_core_extra}'")

    t1 = time.time()
    res = subprocess.run(  # nosem
        args, capture_output=True, timeout=hard_timeout or None
    )
    t2 = time.time()

    status = res.returncode
    print(f"semgrep exit status: {status}")
    if status == 0:
        print("success")
    elif status == 3:
        print("warning: some files couldn't be parsed")
    else:
        print("************* Semgrep stdout *************")
        print(res.stdout)
        print("************* Semgrep stderr *************")
        print(res.stderr)
        res.check_returncode()

    try:
        stdout_decoded = res.stdout.decode("utf-8")
        semgrep_results = json.loads(stdout_decoded)
    except (ValueError, json.JSONDecodeError):
        logger.error(
            f"Unable to decode the Semgrep result as JSON. Exiting.\n******\n{res.stdout}"
        )
        sys.exit(1)

    return t2 - t1, semgrep_results


def prepare_rule_cache_for_this_run(
    setup_data: BenchmarkRunSetupData,
    clean: bool = False,
) -> Tuple[Path, List[Path]]:
    """Make sure the rule configs of a run are present in the on-disk cache.

    A rule config already in the cache is reused unless ``clean`` is set, in
    which case it is resolved again. A config whose resolution raises
    ValueError is reported and left out of the result.

    Returns:
        The cache directory of this run, and the cache paths of the rule
        configs that are available.
    """
    run_cache_dir: Path = RULE_CONFIG_CACHE_DIR / setup_data.run_name
    run_cache_dir.mkdir(parents=True, exist_ok=True)
    logger.info(
        f"Rule cache for run {setup_data.run_name} created at {run_cache_dir}"
    )

    cached_paths: List[Path] = []
    for rule_config in setup_data.rule_configs:
        logger.info(f"Checking for rule config '{rule_config}' in cache")
        cache_entry = run_cache_dir / rule_config.normalize_rule_config_name()

        # Reuse what is on the filesystem unless a refresh was requested.
        if cache_entry.exists() and not clean:
            logger.info(
                f"Rule config '{cache_entry}' already exists, using this one"
            )
            cached_paths.append(cache_entry)
            continue

        # Otherwise fetch the rules into the cache.
        try:
            rule_config.resolve_to_cache(cache_entry)
        except ValueError as ve:
            logger.warning(f"{str(ve)}. Skipping this config")
            continue
        cached_paths.append(cache_entry)

    return run_cache_dir, cached_paths


def prepare_benchmark_run(
    benchmark_config: SemgrepBenchmarkConfig, clean: bool
) -> List[Corpus]:
    """
    Prepare everything a benchmarking run needs and describe it as corpuses.

    For every run listed in the config:
    - its rule cache is populated (see prepare_rule_cache_for_this_run)
    - a 'prep' file is generated for each of its repositories
    - one Corpus per repository is returned, targeting 'input/<repo name>'
    """
    prepared: List[Corpus] = []
    for setup_data in benchmark_config.benchmark_setup_data:
        run_name = setup_data.run_name
        logger.info(f"Setting up benchmark run for run '{run_name}'")
        rule_cache_dir, _ = prepare_rule_cache_for_this_run(setup_data, clean)

        for repository in setup_data.repositories:
            # The last component of the URL names the repository.
            repo_name = repository.url.split("/")[-1]
            generate_prep_file(
                run_name,
                repo_name,
                rule_cache_dir.absolute(),
                repository.url,
                repository.commit_hash,
            )
            corpus = Corpus(
                run_name,
                rule_cache_dir.absolute(),
                Path("input") / repo_name,
                semgrep_options=setup_data.semgrep_options,
            )
            prepared.append(corpus)

    return prepared


def generate_prep_file(
    name: str,
    repo_name: str,
    rule_cache_dir: Path,
    url: str,
    commit_hash: str,
) -> None:
    """Write the 'prep' file of a benchmark from PREP_FILE_TEMPLATE.

    The file is placed in '<CORPUS_LOCATION>/<name>/prep'; the directory is
    created if needed. The template is filled in with the repository name,
    the rule cache directory, the repository URL and the commit hash.
    """
    bench_dir = Path(f"{CORPUS_LOCATION}/{name}")
    bench_dir.mkdir(parents=True, exist_ok=True)

    prep_path = bench_dir / "prep"
    logger.info(f"Generating 'prep' file at {prep_path}")
    # Create the file with mode 0o755 if it does not exist yet, then fill it.
    prep_path.touch(0o755, exist_ok=True)
    contents = PREP_FILE_TEMPLATE.format(
        name=repo_name,
        rule_cache_dir=rule_cache_dir,
        url=url,
        commit_hash=commit_hash,
    )
    prep_path.write_text(contents)


def get_semgrep_version():
    """Return what 'semgrep --version' prints on stdout, without surrounding
    whitespace. The exit status of the command is not checked."""
    completed = subprocess.run(["semgrep", "--version"], capture_output=True)
    version_text = completed.stdout.decode("utf-8")
    return version_text.strip()


def run_benchmarks(
    docker: str,
    dummy: bool,
    small_only: bool,
    all: bool,
    internal: bool,
    std_only: bool,
    config_file: Optional[Path],
    filter_corpus: str,
    filter_variant: str,
    hard_timeout: int,
    output_time_per_rule_json: str,
    plot_benchmarks: bool,
    upload: bool,
    include_time: bool,
    clean: bool,
    trace: bool,
    trace_endpoint: Optional[str],
    summary_file_path: str,
    findings_file_path: str,
    called_dir: str,
) -> None:
    """Run the selected semgrep variants on the selected corpuses.

    Each corpus is prepared, then every variant is run on it. The duration of
    each run is reported and its findings are compared with those of the STD
    variant on the same corpus. A summary is printed at the end and,
    depending on the options, saved, plotted or exported per rule.

    Args:
        docker: docker image to run semgrep from, or empty to use the
            'semgrep' found in the PATH.
        dummy, internal, small_only, all, config_file: corpus selection,
            checked in that order; the first one set wins. With none set,
            the small and medium corpuses are used.
        std_only: run only the standard variants.
        filter_corpus: regexp; keep only the corpuses whose name matches.
        filter_variant: regexp; keep only the variants whose name matches.
        hard_timeout: maximum number of seconds for one semgrep run (0 = no
            limit); a run exceeding it is skipped.
        output_time_per_rule_json: file for per-rule timing data, or None to
            disable. Setting it forces the standard variants.
        plot_benchmarks: display a bar chart of the durations.
        upload: upload each measurement with upload_result.
        include_time: ask semgrep for timing data (--time).
        clean: refresh the rule cache instead of reusing it.
        trace, trace_endpoint: tracing options passed on to run_semgrep.
        summary_file_path: file, relative to ``called_dir``, in which the
            durations are saved as JSON; nothing is saved when empty.
        findings_file_path: same, for the STD findings of each corpus.
        called_dir: directory relative to which output files are written.
    """

    def get_corpuses(config_file: Path) -> List[Corpus]:
        config_file_path = config_file.absolute()
        return prepare_benchmark_run(
            SemgrepBenchmarkConfig.parse_config(config_file_path), clean
        )

    variants = SEMGREP_VARIANTS
    if std_only or output_time_per_rule_json:
        variants = STD_VARIANTS
    if filter_variant:
        variants = [
            x for x in variants if re.search(filter_variant, x.name) is not None
        ]
    # TODO: make benchmarking time-per-rule data work with all variants.
    if output_time_per_rule_json:
        logger.info("Running only the standard variants.")
        variants = STD_VARIANTS

    results_msgs = []
    durations = []
    results: dict = {variant.name: [] for variant in variants}
    saved_findings = []
    output_per_rule_json_results = RepositoryTimePerRule(
        output_file=output_time_per_rule_json
    )

    if dummy:
        corpuses = DUMMY_CORPUSES
    elif internal:
        corpuses = INTERNAL_CORPUSES
    elif small_only:
        corpuses = get_corpuses(SMALL_CORPUSES)
    elif all:
        corpuses = (
            get_corpuses(SMALL_CORPUSES)
            + get_corpuses(MEDIUM_CORPUSES)
            + get_corpuses(LARGE_CORPUSES)
        )
    elif config_file:
        corpuses = get_corpuses(config_file)
    else:
        corpuses = get_corpuses(SMALL_CORPUSES) + get_corpuses(MEDIUM_CORPUSES)
    if filter_corpus:
        corpuses = [
            x for x in corpuses if re.search(filter_corpus, x.name) is not None
        ]

    # TODO: multithread here, dedicating one core to each benchmark.
    for corpus in corpuses:
        with chdir(f"{CORPUS_LOCATION}/{corpus.name}"):
            try:
                corpus.prep()
            except Exception as e:
                logger.warning(
                    f"Could not execute 'prep' file for {corpus.name}. Please check it exists. Will skip this corpus. Error: {e}"
                )
                continue

            # Stays empty until the STD variant has run successfully on this
            # corpus; it may never be set if STD is filtered out or times out.
            std_findings: dict = {}
            options = sorted(
                re.sub(r"\W+", "", option) for option in corpus.semgrep_options
            )
            for variant in variants:
                # Run variant
                name = ".".join(
                    ["semgrep"] + options + ["bench", corpus.name, variant.name]
                )
                metric_name = ".".join([name, "duration"])
                print(f"------ {name} ------")
                try:
                    duration, semgrep_results = run_semgrep(
                        docker,
                        corpus,
                        variant,
                        include_time,
                        hard_timeout,
                        trace,
                        trace_endpoint,
                    )
                except subprocess.TimeoutExpired:
                    logger.warning(
                        f"A hard timeout was set for {hard_timeout} seconds."
                    )
                    logger.warning(
                        f"The test for '{variant.name}' on '{corpus.name}' exceeded the timeout. This test will be skipped."
                    )
                    continue

                if output_time_per_rule_json is not None:
                    output_per_rule_json_results.times_per_file_to_times_per_rule(
                        corpus.name, semgrep_results
                    )

                # Report results
                msg = f"{metric_name} = {duration:.3f} s"
                print(msg)
                results_msgs.append(msg)

                findings, timings = standardize_findings(semgrep_results, include_time)
                durations.append(
                    {
                        "name": name,
                        "time": duration,
                    }
                )
                results[variant.name].append(duration)

                # `upload` is the boolean parameter here, not the module-level
                # function of the same name.
                if upload:
                    upload_result(variant.name, metric_name, duration, timings)

                # Check correctness
                num_results = len(findings["results"])
                num_errors = len(findings["errors"])
                print(f"Result: {num_results} findings, {num_errors} parse errors")

                if variant.name == STD:
                    std_findings = findings
                elif not std_findings:
                    # Without a reference there is nothing to compare against;
                    # say so instead of failing on the missing keys.
                    results_msgs[-1] += " WARNING: no std findings to compare with"
                elif findings["results"] != std_findings["results"]:
                    fd_len, sd_len = output_differences(
                        findings["results"], std_findings["results"], variant.name
                    )
                    results_msgs[
                        -1
                    ] += f" ERROR: {fd_len} extra findings, {sd_len} missing findings"
                elif len(findings["errors"]) > len(std_findings["errors"]):
                    results_msgs[-1] += " WARNING: more errors than std"

            if std_findings:
                saved_findings.append(
                    {
                        "name": ".".join(["semgrep", "bench", corpus.name]),
                        "findings": jsonify_findings(std_findings),
                    }
                )
            else:
                logger.warning(
                    f"No '{STD}' findings were collected for '{corpus.name}'. No findings will be saved for this corpus."
                )

    # Show summary data
    print("\n".join(results_msgs))

    if summary_file_path:
        with chdir(called_dir):
            with open(summary_file_path, "w") as f:
                json.dump(durations, f)

    if findings_file_path:
        with chdir(called_dir):
            with open(findings_file_path, "w") as f:
                json.dump(saved_findings, f)

    if plot_benchmarks:
        # Imported here so that these packages are only needed for plotting.
        import matplotlib.pyplot as plt
        import pandas as pd

        # NOTE(review): this presumably needs one duration per corpus for
        # every variant; skipped corpuses or timed-out runs look likely to
        # make the DataFrame construction fail. To be confirmed.
        indexes = [corpus.name for corpus in corpuses]
        plotdata = pd.DataFrame(results, index=indexes)
        plotdata.plot(kind="bar")
        plt.show()

    if output_time_per_rule_json is not None:
        with chdir(called_dir):
            output_per_rule_json_results.print_repo_to_times_per_rule()


def main() -> None:
    """Parse the command line and run the benchmarks.

    The benchmarks are run from the directory containing this script. The
    --config path and the output files are interpreted relative to the
    directory the script was called from.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--docker",
        metavar="DOCKER_IMAGE",
        type=str,
        help="use the specified docker image for semgrep, such as returntocorp/semgrep:develop",
    )
    parser.add_argument(
        "--dummy",
        help="run quick, fake benchmarks for development purposes",
        action="store_true",
    )
    parser.add_argument(
        "--internal",
        help="run internal benchmarks for development purposes",
        action="store_true",
    )
    parser.add_argument(
        "--small-only",
        help="only run small benchmarks (20s per run or less)",
        action="store_true",
    )
    parser.add_argument(
        "--all",
        help="run all benchmarks (takes >1 day)",
        action="store_true",
    )
    parser.add_argument(
        "--config",
        "-f",
        default=None,
        help="run benchmarks according to a config file",
    )
    parser.add_argument(
        "--std-only", help="only run the default semgrep", action="store_true"
    )
    parser.add_argument(
        "--filter-corpus",
        metavar="REGEXP",
        type=str,
        help="run the corpus only if it satisfies the regexp (e.g., 'std|exp')",
    )
    parser.add_argument(
        "--filter-variant",
        metavar="REGEXP",
        type=str,
        help="run the variant only if it satisfies the regexp (e.g., 'dr.*') ",
    )
    parser.add_argument(
        "--upload", help="upload results to semgrep dashboard", action="store_true"
    )
    parser.add_argument(
        "--save-to",
        metavar="FILE_NAME",
        type=str,
        help="save timing summary to the json file given by the argument",
    )
    parser.add_argument(
        "--save-findings-to",
        metavar="FILE_NAME",
        type=str,
        help="save findings to the json file given by the argument",
    )
    parser.add_argument(
        "--plot-benchmarks",
        help="display a graph of the benchmark results",
        action="store_true",
    )
    parser.add_argument(
        "--hard-timeout",
        help="skip a semgrep run if it takes longer than this many seconds (default: 0, no limit)",
        type=int,
        default=0,
    )
    parser.add_argument(
        "--clean",
        help="Regenerate rule cache and prep files",
        action="store_true",
    )
    parser.add_argument(
        "--trace",
        help="Run tracing on the benchmarks",
        action="store_true",
    )
    # Spelled with hyphens like every other option; the original underscore
    # spelling is kept as an alias so existing invocations keep working.
    parser.add_argument(
        "--trace-endpoint",
        "--trace_endpoint",
        dest="trace_endpoint",
        metavar="SEMGREP_OTEL_ENDPOINT",
        help="Add endpoint for collecting tracing",
        type=str,
    )
    parser.add_argument(
        "--output-time-per-rule-json",
        nargs="?",
        default=None,
        const="repo_to_rule_time.json",
        help="output a json file that shows timing information for each rule.",
    )
    parser.add_argument("--no-time", help="disable time-checking", action="store_true")
    args = parser.parse_args()

    cur_dir = os.path.dirname(os.path.abspath(__file__))
    called_dir = os.getcwd()

    with chdir(cur_dir):
        # Resolve the config path against the caller's directory, since the
        # working directory is now the script's own directory.
        config_file = (Path(called_dir) / args.config) if args.config else None
        run_benchmarks(
            docker=args.docker,
            dummy=args.dummy,
            small_only=args.small_only,
            all=args.all,
            internal=args.internal,
            std_only=args.std_only,
            config_file=config_file,
            filter_corpus=args.filter_corpus,
            filter_variant=args.filter_variant,
            hard_timeout=args.hard_timeout,
            output_time_per_rule_json=args.output_time_per_rule_json,
            plot_benchmarks=args.plot_benchmarks,
            upload=args.upload,
            include_time=not args.no_time,
            clean=args.clean,
            trace=args.trace,
            trace_endpoint=args.trace_endpoint,
            summary_file_path=args.save_to,
            findings_file_path=args.save_findings_to,
            called_dir=called_dir,
        )


if __name__ == "__main__":
    main()
