#!/usr/bin/env python
#=========================================================================
# sv-verilog [options] prompt-filename
#=========================================================================
#
#  -v --verbose        Display the prompt
#  -h --help           Display this message
#  -l --list-models    List supported models
#  -m --model          LLM model to use (default: gpt-3.5-turbo)
#  -t --temperature    LLM model temperature (default: 0.85)
#  -p --top-p          LLM model top_p (default: 0.95)
#  -n --max-tokens     LLM model max_tokens (default: 1024)
#  -e --explain        Let the model include an explanation
#  -x --examples       Include the in-context examples for a number of shots (default: 0)
#  -r --rules          Include the in-context rules
#     --task           Task to prompt (default: code-complete-iccad2023)
#     --output         File to write extracted code
#
# Use GPT to generate verilog from the given prompt.
#
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT
# Author : Christopher Batten, NVIDIA and Nathaniel Pinckney, NVIDIA
#

import argparse
import sys
import os
import re
import time
from collections import namedtuple


from langchain_nvidia_ai_endpoints import ChatNVIDIA
from langchain_openai import ChatOpenAI
from langchain.schema   import SystemMessage, HumanMessage
from langchain.callbacks import get_openai_callback

#-------------------------------------------------------------------------
# Command line processing
#-------------------------------------------------------------------------

class ArgumentParserWithCustomError(argparse.ArgumentParser):
  """ArgumentParser whose usage/error text is this script's own comment header."""

  def error( self, msg = "" ):
    """Print an optional error plus the script's leading comment block, then exit.

    The usage text is read from the file named by sys.argv[0]: every leading
    line that starts with '#' is considered part of the header. Line index 2
    (the usage line) and all lines from index 4 onwards are printed with the
    leading '#' removed; the shebang and the two separator lines are skipped.

    Exits with status 1 when msg is non-empty and status 0 otherwise (the
    latter is how --help is served). This method never returns.
    """
    if ( msg ): print("\n ERROR: %s" % msg)
    print("")

    exit_status = int( msg != "" )

    # Use a context manager so the script file is closed before exiting
    with open( sys.argv[0] ) as script:
      for ( lineno, line ) in enumerate( script ):
        # The header ends at the first line that is not a comment
        if ( line[0] != '#' ): break
        if ( (lineno == 2) or (lineno >= 4) ): print( line[1:].rstrip("\n") )

    # Always exit, even if the file consisted solely of comment lines
    sys.exit( exit_status )

def parse_cmdline( argv=None ):
  """Parse the command line and return the populated options namespace.

  argv is the list of arguments to parse; when None (the default) argparse
  falls back to sys.argv[1:]. If --help is given, the usage header is
  printed through the parser's error() method and the process exits.

  The defaults here are kept in sync with the usage header at the top of
  the file, since that header is what gets printed as the help text.
  """
  p = ArgumentParserWithCustomError( add_help=False )

  p.add_argument( "-h", "--help",        action="store_true" )
  p.add_argument( "-v", "--verbose",     action="store_true" )
  p.add_argument( "-l", "--list-models", action="store_true" )
  p.add_argument( "-m", "--model",       type=str,   default="gpt-3.5-turbo" )
  # Documented default is 0.85; keep the code and the help text in agreement
  p.add_argument( "-t", "--temperature", type=float, default=0.85 )
  p.add_argument( "-n", "--max-tokens",  type=int,   default=1024 )
  p.add_argument( "-p", "--top-p",       type=float, default=0.95 )
  p.add_argument( "-e", "--explain",     action="store_true" )
  p.add_argument( "-x", "--examples",    type=int,   default=0 )
  p.add_argument( "-r", "--rules",       action="store_true" )
  p.add_argument(       "--output",      type=str,   default="-" )
  p.add_argument(       "--task",        type=str,   default="code-complete-iccad2023" )
  p.add_argument( "prompt_filename" )

  opts = p.parse_args( argv )

  # An empty message makes error() print only the usage text
  if opts.help: p.error()
  return opts

#-------------------------------------------------------------------------
# Models
#-------------------------------------------------------------------------

# Models queried through ChatOpenAI in main()

openai_models = [
  "gpt-3.5-turbo",
  "gpt-4",
  "gpt-4-turbo",
  "gpt-4o",
]

# Models queried through ChatNVIDIA in main()

nim_chat_models = [
  "ai-llama2-70b",
  "ai-llama3-70b",
  "meta/llama-3.1-8b-instruct",
  "meta/llama-3.1-70b-instruct",
  "meta/llama-3.1-405b-instruct",
  "ai-codellama-70b",
  "ai-gemma-7b",
  "ai-codegemma-7b",
  "ai-mistral-7b-instruct-v2",
  "ai-mixtral-8x7b-instruct",
  "ai-mistral-large",
]

# Models that are not queried at all: main() writes the prompt to disk and
# reads the response back from a file next to the output file.

manual_models = [
  "manual-rtl-coder",
  "manual-deepseek-coder-6.7b",
  "manual-deepseek-coder-33b",
]

# Alternative spellings accepted on the command line, mapped to the
# canonical model name used in the lists above.

model_aliases = {
  "gpt3.5-turbo"               : "gpt-3.5-turbo",
  "gpt4"                       : "gpt-4",
  "gpt4-turbo"                 : "gpt-4-turbo",
  "ai-llama-3_1-8b-instruct"   : "meta/llama-3.1-8b-instruct",
  "ai-llama-3_1-70b-instruct"  : "meta/llama-3.1-70b-instruct",
  "ai-llama-3_1-405b-instruct" : "meta/llama-3.1-405b-instruct",
}

# Models for which main() also prepends the system message to the human
# message.

models_to_repeat_system_to_human_messages = [
  "ai-mistral-large",
]

# Values accepted by --task

available_tasks = [
  "code-complete-iccad2023",
  "spec-to-rtl",
]

#-------------------------------------------------------------------------
# VerboseOutput
#-------------------------------------------------------------------------

class VerboseOutput:
  """Printer that only writes to stdout when verbose mode is enabled."""

  def __init__( self, verbose ):
    # Truthy value enables output; falsy value silences it
    self.verbose = verbose

  def print( self, string ):
    """Print string to stdout if verbose is set, otherwise do nothing."""
    if not self.verbose:
      return
    print( string )

#-------------------------------------------------------------------------
# Context
#-------------------------------------------------------------------------

# Per-task prompt pieces, keyed by task name. Each entry provides the
# system message and the text that main() places ahead of the problem
# prompt for that task.

prompts = {

  'code-complete-iccad2023' : {
    'system_msg' : """
You only complete chats with syntax correct Verilog code. End the Verilog module code completion with 'endmodule'. Do not include module, input and output definitions.
""",
    'prompt_prefix' : """
// Implement the Verilog module based on the following description. Assume that signals are positive clock/clk triggered unless otherwise stated.
""",
  },

  'spec-to-rtl' : {
    'system_msg' : """
You are a Verilog RTL designer that only writes code using correct Verilog syntax.
""",
    'prompt_prefix' : "",
  },

}

# # This is a test system message to see if system messages are working.
# iccad2023_system_msg="""
# Translate the following code to C++.
# """

# Optional coding rules appended to the prompt when --rules is given

prompt_rules_suffix = """
Here are some additional rules and coding conventions.

 - Declare all ports and signals as logic; do not to use wire or reg.

 - For combinational logic with an always block do not explicitly specify
   the sensitivity list; instead use always @(*).

 - All sized numeric constants must have a size greater than zero
   (e.g, 0'b0 is not a valid expression).

 - An always block must read at least one signal otherwise it will never
   be executed; use an assign statement instead of an always block in
   situations where there is no need to read any signals.

 - if the module uses a synchronous reset signal, this means the reset
   signal is sampled with respect to the clock. When implementing a
   synchronous reset signal, do not include posedge reset in the
   sensitivity list of any sequential always block.
"""

# Earlier variants of the no-explanation suffix, kept for reference:

# prompt_no_explain_suffix="""
# Please do not include any explanations in your response.
# """

# prompt_no_explain_suffix="""
# Enclose your code with <CODE> and </CODE>. Only output the code snippet
# and do NOT output anything else.
# """

# Suffix appended when --explain is not given; the [BEGIN]/[DONE] markers
# are what the response extraction in main() looks for.

prompt_no_explain_suffix = """
Enclose your code with [BEGIN] and [DONE]. Only output the code snippet
and do NOT output anything else.
"""

#-------------------------------------------------------------------------
# Main
#-------------------------------------------------------------------------

def _scan_response( text ):
  """First pass over a response: count code fences and look for a module header.

  Returns (backticks_count, module_already_exists). backticks_count is the
  number of lines starting with three backticks. module_already_exists is
  True when a line starting with "module TopModule" is seen, unless a line
  starting with "endmodule" appears before any such header.
  """
  backticks_count              = 0
  endmodule_before_startmodule = False
  module_already_exists        = False

  for line in text.splitlines():
    if line.startswith("```"):
      backticks_count += 1
    elif line.startswith("endmodule"):
      if not module_already_exists:
        endmodule_before_startmodule = True
    elif line.startswith("module TopModule"):
      module_already_exists = True

  if endmodule_before_startmodule:
    module_already_exists = False

  return backticks_count, module_already_exists


def _extract_delimited_code( text ):
  """Collect the lines found between [BEGIN] and [DONE] in a response.

  Returns (code_lines, found_start, found_end). Text sharing a line with a
  marker is kept with the marker removed. Lines after the first [DONE] are
  ignored.
  """
  code_lines  = []
  found_start = False
  found_end   = False

  for line in text.splitlines():

    if not found_start:
      if line.strip() == "[BEGIN]":
        found_start = True
      elif line.lstrip().startswith("[BEGIN]"):
        code_lines.append( line.lstrip().replace("[BEGIN]","") )
        found_start = True

    elif not found_end:
      if line.strip() == "[DONE]":
        found_end = True
      elif line.rstrip().endswith("[DONE]"):
        code_lines.append( line.rstrip().replace("[DONE]","") )
        found_end = True
      else:
        code_lines.append( line )

  return code_lines, found_start, found_end


def _echo_code( text, file, model, backticks_count, module_already_exists ):
  """Second pass over a response: write the lines that look like code to file.

  With two or more code fences, only lines between the first and second
  fence are written. Otherwise lines are written until after the first
  endmodule, skipping anything ahead of the module header when the response
  carries one. Returns (found_module, found_endmodule).
  """
  found_first_backticks  = False
  found_second_backticks = False
  found_module           = False
  found_endmodule        = False

  for line in text.splitlines():
    echo_line = True
    stripped  = line.strip()

    # Only tracked here; the header line itself is echoed like any other
    if stripped.startswith("module TopModule"):
      found_module = True

    if backticks_count >= 2:
      if (not found_first_backticks) or found_second_backticks:
        echo_line = False
    else:
      if found_endmodule:
        echo_line = False
      if module_already_exists and not found_module:
        echo_line = False

    if line.startswith("```"):
      if not found_first_backticks:
        found_first_backticks = True
      else:
        found_second_backticks = True
      echo_line = False
    elif stripped.startswith("endmodule"):
      found_endmodule = True

    if echo_line:
      if model == "manual-rtl-coder":
        # Drop whatever this model appends after the final endmodule
        if stripped.startswith("endmodule"):
          line = line.rsplit('endmodule', 1)[0] + "\n" + "endmodule"
      print( line, file=file )

  return found_module, found_endmodule


def main():
  """Build the prompt, obtain a model response, and write the extracted code.

  Steps: parse the command line, validate model/task, assemble the full
  prompt (optional in-context examples, rules and no-explanation suffix),
  query the selected model (or, for manual models, dump the prompt to disk
  and read a previously saved response), print the response, and finally
  write the extracted Verilog to the file named by --output.
  """

  opts = parse_cmdline()

  # List supported models and exit

  if opts.list_models:
    model_groups = [
      ( "OpenAI Models", openai_models   ),
      ( "NIM Models",    nim_chat_models ),
      ( "Manual Models", manual_models   ),
    ]

    for title, names in model_groups:
      print( "" )
      print( title )
      for name in names:
        print( f" - {name}" )

    print( "" )
    print( "Model Aliases" )

    for key,value in model_aliases.items():
      print( f" - {key} : {value}" )

    print( "" )
    return

  # Check model, task, and number of example shots

  model = model_aliases.get( opts.model, opts.model )

  if model not in openai_models + nim_chat_models + manual_models:
    print("")
    print(f"ERROR: Unknown model {model}")
    print("")
    return

  task = opts.task
  if task not in available_tasks:
    print("")
    print(f"ERROR: Unknown task {task}")
    print("")
    return

  # Explicit check instead of an assert, which is stripped under python -O
  if opts.examples < 0:
    print("")
    print(f"ERROR: Number of example shots must not be negative ({opts.examples})")
    print("")
    return

  system_msg    = prompts[task]['system_msg']
  prompt_prefix = prompts[task]['prompt_prefix']

  out = VerboseOutput( opts.verbose )

  # Log parameters

  problem = "?"
  if opts.prompt_filename.endswith("_prompt.txt"):
    problem = os.path.basename(opts.prompt_filename[:-11])

  temperature = opts.temperature
  top_p       = opts.top_p
  max_tokens  = opts.max_tokens

  out.print( "" )
  out.print( f"problem     = {problem}"     )
  out.print( f"model       = {model}"       )
  out.print( f"temperature = {temperature}" )
  out.print( f"top_p       = {top_p}"       )
  out.print( f"max_tokens  = {max_tokens}"  )

  # Read the prompt file

  with open(opts.prompt_filename) as prompt_file:
    prompt = prompt_file.read()

  # Create full prompt

  full_prompt = ""
  if model in models_to_repeat_system_to_human_messages:
    full_prompt = system_msg + "\n"

  if opts.examples != 0:
    # Resolve relative to this script; plain string concatenation would
    # yield a path at the filesystem root when dirname(__file__) is empty.
    prompt_example_prefix_filename = os.path.join(
      os.path.dirname( os.path.abspath(__file__) ),
      f"verilog-example-prefix_{task}_{opts.examples}-shot.txt",
    )
    with open(prompt_example_prefix_filename) as f:
      prompt_example_prefix = f.read()
    full_prompt += prompt_example_prefix

  if task == "code-complete-iccad2023":

    # Everything ahead of the TopModule header is the problem description
    # and is turned into Verilog comments; the header onwards is kept as is.

    prefix = True
    prefixed_prompt = []

    for line in prompt.splitlines():
      if "module TopModule" in line:
        prefixed_prompt.append("")
        prefix = False

      if prefix:
        prefixed_prompt.append("// " + line)
      else:
        prefixed_prompt.append(line)

    full_prompt += prompt_prefix + "\n".join(prefixed_prompt)

  elif task == "spec-to-rtl":

    # "Question/Answer" prompt style
    # TODO: Old prompt style ("### Problem" / "### Solution"), resolve

    full_prompt += "\nQuestion:\n"
    full_prompt += prompt.strip() + "\n"

    if opts.rules:
      full_prompt += prompt_rules_suffix

    if not opts.explain:
      full_prompt = full_prompt.rstrip() + "\n" + prompt_no_explain_suffix

    full_prompt += "\nAnswer:\n"

  # Print system message and prompt

  out.print("")
  out.print("System Message")
  out.print("-"*74)
  out.print(system_msg)

  out.print("Prompt")
  out.print("-"*74)
  out.print(full_prompt.rstrip())

  # Create LLM messages

  msgs = [ SystemMessage(system_msg), HumanMessage(full_prompt) ]

  # Set up the LLM; manual models instead get the prompt dumped next to
  # the output file so that it can be run by hand.

  if model in openai_models:
    llm = ChatOpenAI(
      model       = model,
      temperature = temperature,
      top_p       = top_p,
      max_tokens  = max_tokens,
    )
  elif model in nim_chat_models:
    llm = ChatNVIDIA(
      model       = model,
      temperature = temperature,
      top_p       = top_p,
      max_tokens  = max_tokens,
    )
  elif model in manual_models:
    output_prompt_filename, junk_ext = os.path.splitext(opts.output)
    with open( output_prompt_filename + "_fullprompt.txt", 'w' ) as output_prompt_file:
      print( full_prompt.rstrip(), file=output_prompt_file )
    with open( output_prompt_filename + "_systemprompt.txt", 'w' ) as output_prompt_file:
      print( system_msg, file=output_prompt_file )
  else:
    # should never reach here
    return

  # Query the LLM

  if model not in manual_models:
    for _ in range(10):
      try:
        with get_openai_callback() as cb:
          resp = llm.invoke(msgs)
        break
      except Exception as e:
        print("")
        print("ERROR: LLM query failed, retrying in 20 seconds")
        print(f"{type(e)}")
        print(f"{e}")
        print("")
        time.sleep(20)
    else:
      # Every attempt failed: there is no response to process, so stop
      # here rather than hitting an undefined variable further down.
      print("")
      print("ERROR: LLM query failed after 10 attempts, giving up")
      print("")
      sys.exit(1)
  else:
    output_prompt_filename, junk_ext = os.path.splitext(opts.output)
    input_response_filename = output_prompt_filename + "_response.txt"

    # Use a previously saved response if there is one, else a placeholder
    if os.path.exists(input_response_filename):
      with open( input_response_filename) as input_response_file:
        resp = namedtuple('Response', ['content'])(content=input_response_file.read())
    else:
      resp = namedtuple('Response', ['content'])(content="// Manually run")

    # Stand-in for the callback object so the statistics below still print
    cb = namedtuple('Callback', ['prompt_tokens', 'resp_tokens', 'total_tokens', 'total_cost', 'completion_tokens'])(prompt_tokens=0, resp_tokens=0, total_tokens=0, total_cost=0, completion_tokens=0)

  # Display the response (always printed, even without --verbose)

  out.print("")
  out.print("Response")
  out.print("-"*74)

  out.print("")
  print(resp.content)
  out.print("")

  # Display statistics

  out.print("Statistics")
  out.print("-"*74)

  out.print("")
  out.print(f"prompt_tokens = {cb.prompt_tokens}")
  out.print(f"resp_tokens   = {cb.completion_tokens}")
  out.print(f"total_tokens  = {cb.total_tokens}")
  out.print(f"total_cost    = {cb.total_cost}")
  out.print("")

  # Extract code from response. The output file is opened exactly once and
  # the with-statement closes it on every exit path, including sys.exit.

  comment_delim = "//"

  with open( opts.output, 'w' ) as file:
    print( "", file=file )

    # First pass to find backticks and if module exists
    backticks_count, module_already_exists = _scan_response( resp.content )

    if task == "code-complete-iccad2023":

      # If module doesn't exist (which it shouldn't for code completion)
      # then print out the interface
      if not module_already_exists:
        if opts.prompt_filename.endswith("_prompt.txt"):
          ifc_filename = opts.prompt_filename[:-11] + "_ifc.txt"
        else:
          print("Error: prompt filename doesn't end with _prompt!")
          sys.exit(-1)
        with open(ifc_filename) as ifc_file:
          print(ifc_file.read() + "\n", file=file)

      # Second pass print out as appropriate
      found_module, found_endmodule = _echo_code(
        resp.content, file, model, backticks_count, module_already_exists )

      print( "", file=file )

      # Trailing marker comments flag irregular responses in the output
      if backticks_count == 1 or backticks_count > 2:
        print( comment_delim + " VERILOG-EVAL: abnormal backticks count", file=file )
        print( "", file=file )
      if found_module:
        print( comment_delim + " VERILOG-EVAL: errant inclusion of module definition", file=file )
        print( "", file=file )
      if not found_endmodule:
        print( comment_delim + " VERILOG-EVAL: endmodule not found", file=file )
        print( "", file=file )

    elif task == "spec-to-rtl":

      # Scan response for code using [BEGIN][DONE] (formerly <CODE></CODE>)
      found_code_lines, found_code_start, found_code_end = \
        _extract_delimited_code( resp.content )

      if found_code_start and found_code_end:
        for line in found_code_lines:
          print( line, file=file )

      # If neither marker was found, fall back to the code-completion style
      # extraction based on backticks.
      # NOTE(review): a response with [BEGIN] but no [DONE] writes neither
      # code nor a marker comment -- confirm this is intended.

      if not found_code_start and not found_code_end:
        _echo_code(
          resp.content, file, model, backticks_count, module_already_exists )

        print( "", file=file )

        # Marker text still mentions <CODE></CODE>; left unchanged here,
        # presumably other tooling matches on it -- TODO confirm.
        print( comment_delim + " VERILOG-EVAL: response did not use <CODE></CODE> correctly", file=file )

    print( "", file=file )

# Only run when executed as a script, not when the file is imported
if __name__ == "__main__":
  main()
