#! /usr/bin/env python3
#
# Compare the benchmarks run for two different versions of semgrep.
# It's meant to run within a semgrep pull request.
# This posts a GitHub PR comment if a performance anomaly is detected.
# It no longer fails on big anomalies: the benchmarks are not reliable enough
# to be tied to single PRs.
#
# Testing: run the companion 'test-compare-perf' script.
#
import json
import os
import sys
from collections import namedtuple
from pathlib import Path
from typing import List

import requests


def send_comment(message: str, github_token: str, pull_request_number: str) -> None:
    """
    Post `message` as a comment on the given semgrep pull request.

    The head commit hash is read from the JSON file named by the
    GITHUB_EVENT_PATH environment variable, at
    obj["pull_request"]["head"]["sha"], and is sent along with the body.

    If pull_request_number is "" then does nothing.
    The message is printed after the request has been sent; an HTTP error
    status is then raised through raise_for_status().
    """
    if pull_request_number == "":
        return

    comments_url = (
        "https://api.github.com/repos/returntocorp/semgrep/issues/"
        f"{pull_request_number}/comments"
    )

    session = requests.Session()
    session.headers.update(
        {
            "Authorization": f"Token {github_token}",
            "Accept": "application/vnd.github.comfort-fade-preview+json",
        }
    )

    event = json.loads(Path(os.getenv("GITHUB_EVENT_PATH")).read_text())
    head_sha = event["pull_request"]["head"]["sha"]

    response = session.post(
        comments_url,
        json={"body": message, "commit_id": head_sha},
        timeout=10,
    )

    # Echo the message, then raise if we were unable to comment.
    print(message)
    response.raise_for_status()


# One benchmark measurement: the benchmark's name and its measured time.
Timing = namedtuple("Timing", ["name", "time"])


def read_timing(filename: Path) -> List[Timing]:
    """
    Load benchmark timings from a JSON file.

    The file is expected to hold a JSON array of objects, each with a "name"
    and a "time" key; other keys are ignored. One Timing is returned per
    entry, in file order.

    Raises OSError if the file cannot be opened, json.JSONDecodeError if it
    is not valid JSON, and KeyError if an entry lacks "name" or "time".
    """
    print(f"Reading {filename}")
    # JSON is exchanged as UTF-8; do not depend on the platform's locale
    # encoding, which is what open() falls back to when none is given.
    with open(filename, encoding="utf-8") as f:
        return [Timing(entry["name"], entry["time"]) for entry in json.load(f)]


def select_min_timings(al: List[Timing], bl: List[Timing]) -> List[Timing]:
    """
    Pair the two lists by position and keep the faster entry of each pair.

    When both entries of a pair have the same time, the one from `bl` is
    kept. The result is as long as the shorter of the two inputs, because
    the pairing is done with zip().
    """
    return [
        first if first.time < second.time else second
        for first, second in zip(al, bl)
    ]


def main() -> None:
    """
    Compare baseline and latest benchmark timings and report anomalies.

    Command-line arguments (sys.argv):
      1, 2  the two timing files of the baseline run
      3, 4  the two timing files of the latest run
      5     GitHub token used to post the PR comment
      6     pull request number (optional; without it no comment is posted)

    Prints a per-benchmark line and an aggregate summary to stdout. When at
    least one deviation is worth reporting, the collected messages are also
    handed to send_comment. A slowdown never makes this function fail: the
    error count is tallied but not enforced.
    """
    # There are two versions of each timing to improve stability
    baseline_timing_file_1 = Path(sys.argv[1])
    baseline_timing_file_2 = Path(sys.argv[2])
    latest_timing_file_1 = Path(sys.argv[3])
    latest_timing_file_2 = Path(sys.argv[4])
    github_token = sys.argv[5]

    # Note this is only defined in pull requests
    pull_request_number = sys.argv[6] if len(sys.argv) >= 7 else ""

    # Eliminate "flukes" by taking the fastest time available for each
    # benchmark that ran twice.
    baseline_times = select_min_timings(
        read_timing(baseline_timing_file_1), read_timing(baseline_timing_file_2)
    )
    latest_times = select_min_timings(
        read_timing(latest_timing_file_1), read_timing(latest_timing_file_2)
    )

    # Accumulators
    n = 0
    total_baseline = 0.0
    total_latest = 0.0
    total_rel_dur = 0.0
    min_rel_dur = float("inf")
    max_rel_dur = 0.0

    # Accumulators of warning messages and error messages in the format
    # expected by GitHub comments.
    messages = []
    errors = 0

    # Benchmarks are paired by position; the name shown is the baseline's.
    for baseline, latest in zip(baseline_times, latest_times):
        name = baseline.name
        baseline_time = baseline.time
        latest_time = latest.time

        if baseline_time <= 0:
            # A relative duration is undefined without a positive baseline;
            # leave this benchmark out of every statistic instead of crashing.
            print(f"[{name}] skipped: non-positive baseline time {baseline_time}")
            continue

        n += 1
        total_baseline += baseline_time
        total_latest += latest_time

        rel_dur = latest_time / baseline_time
        diff = latest_time - baseline_time
        total_rel_dur += rel_dur
        min_rel_dur = min(min_rel_dur, rel_dur)
        max_rel_dur = max(max_rel_dur, rel_dur)

        # Our benchmarks are small and we run them in GitHub-hosted runners,
        # so we cannot distinguish between small perf deviations and noise.
        # Without this filter, the warnings fire all the time with a very high
        # false positive rate.
        probably_noise = 0.7 <= rel_dur <= 1.3 and abs(diff) < 2

        print(
            f"[{name}] {rel_dur:.3f}x "
            f"Baseline: {baseline_time:.3f}, Latest: {latest_time:.3f}"
        )

        perc = 100 * (rel_dur - 1)

        # More than 30% slower than baseline counts as an error, regardless
        # of the noise filter.
        if rel_dur > 1.3:
            errors += 1
            messages.append(
                f"🚫 Benchmark {name} is too slow: " f"+{perc:.1f}% (+{diff:.3f} s)"
            )
        elif rel_dur > 1.2 and not probably_noise:
            messages.append(
                f"⚠️  Potential non-blocking slowdown in benchmark {name}: "
                f"+{perc:.1f}% (+{diff:.3f} s)"
            )
        elif rel_dur < 0.8 and not probably_noise:
            messages.append(
                f"🔥 Potential speedup in benchmark {name}: "
                f"{perc:.1f}% ({diff:.3f} s)"
            )

    if n == 0:
        # Nothing comparable was found; the averages below would divide by 0.
        print("No benchmark timings to compare.")
        return

    mean_rel_dur = total_rel_dur / n
    mean_perc = 100 * (mean_rel_dur - 1)
    if mean_perc > 0:
        mean_perc_str = f"{mean_perc:.1f}% slower"
    else:
        mean_perc_str = f"{-mean_perc:.1f}% faster"

    if mean_rel_dur > 1.15:
        errors += 1
        # mean_rel_dur is a ratio, not a duration: report it as a factor.
        messages.append(
            f"🚫 The whole benchmark suite is too slow: "
            f"+{mean_perc:.1f}% ({mean_rel_dur:.3f}x)"
        )

    print(
        f"Average: {mean_rel_dur:.3f}x, "
        f"Min: {min_rel_dur:.3f}x, "
        f"Max: {max_rel_dur:.3f}x"
    )
    print(f"Total Baseline: {total_baseline:.3f} s, Latest: {total_latest:.3f} s")

    # Send PR comment if anything's weird or really wrong
    if messages:
        messages.append(f"{n} benchmarks, {mean_perc_str} on average.")
        # TODO: link to the job output. It seems not possible at the moment
        # without using the GitHub API.
        # See https://github.community/t/github-actions-url-to-run/16029
        # NOTE(review): this text still says that errors block the pull
        # request, although this script no longer fails (see below).
        messages.append(
            "Individual deviations greater than 20% from the baseline are reported. "
            "An individual performance degradation of over 30% or "
            "a global degradation of over 15% is an error and will block "
            "the pull request. "
            "See run output for full results "
            "('Show all checks' > 'Tests / semgrep benchmark tests' 'Details')."
        )
        msg = "\n\n".join(messages)
        print(f"Sending warnings and errors as a PR comment:\n{msg}")
        send_comment(
            msg,
            github_token,
            pull_request_number,
        )

    # We no longer fail because the benchmarks are not reliable enough to be tied
    # to single PRs

    # Old: Fail only after printing and sending all messages
    # assert not errors

    # Assert the rules in aggregate are not more than 7% slower than baseline
    # assert mean_rel_dur < 1.07


# Run the comparison only when this file is executed as a script.
if __name__ == "__main__":
    main()
