#!/usr/bin/python 

import sys
from collections import defaultdict

# Keys of the record kept for every parsed function (see 'Functions')
(NAME, RETTYPE, ARGS_ARRAY, ARGS_STRING) = (
  "NAME", "RETTYPE", "ARGS_ARRAY", "ARGS_STRING"
)

# Global variables

# One record per parsed function, filled in while the interfaces file is read
Functions = defaultdict()

# printf conversion specification used to print an argument of each known
# C type. Pointer and reference types are not listed here: they are handled
# separately by get_format_for_type().
Format = defaultdict()
Format.update({
  "int":                "%d",
  "size_t":             "%zu",   # was "%d": wrong size/signedness for size_t
  "double":             "%lf",
  "float":              "%f",
  "char":               "%c",
  "short":              "%hd",   # was "%h": 'h' is only a length modifier
  "long":               "%ld",
  "long long":          "%lld",
  "long double":        "%Lf",   # was "%lf": long double needs the 'L' modifier
  "long int":           "%ld",
  "long long int":      "%lld",
  "unsigned long long": "%llu",
  "unsigned":           "%u",
  "ptrdiff_t":          "%td",   # was "%p": ptrdiff_t is an integer, not a pointer
})

def Usage():
  """Print the invalid-arguments notice and the command-line synopsis.

  The text goes to standard output, one line per entry, surrounded by
  blank lines.
  """
  usage_lines = (
    "",
    "Error: Invalid arguments!",
    "",
    "SYNTAX:",
    "  wrapgen <API> <INTERFACES-FILE> <EVENT_BASE_NUMBER>",
    "",
  )
  for line in usage_lines:
    # A single parenthesised argument prints the same text whether 'print'
    # is the Python 2 statement or the Python 3 function.
    print(line)

def write_header(fd, rcsid):
  """Write the Extrae license banner and common preamble of a generated file.

  fd    -- writable file-like object; only its write() method is used.
  rcsid -- when true, the static 'rcsid' identification string is emitted
           after the '#include "common.h"' line.
  """
  # Every backslash that belongs to the ASCII-art logo or to a comment frame
  # is spelled '\\' so that no literal depends on Python passing an
  # unrecognised escape sequence (such as '\ ' or '\_') through unchanged.
  # The bytes written are the same as before.
  banner = (
    "/*****************************************************************************\\\n",
    " *                        ANALYSIS PERFORMANCE TOOLS                         *\n",
    " *                                   Extrae                                  *\n",
    " *              Instrumentation package for parallel applications            *\n",
    " *****************************************************************************\n",
    " *     ___     This library is free software; you can redistribute it and/or *\n",
    " *    /  __         modify it under the terms of the GNU LGPL as published   *\n",
    " *   /  /  _____    by the Free Software Foundation; either version 2.1      *\n",
    " *  /  /  /     \\   of the License, or (at your option) any later version.   *\n",
    " * (  (  ( B S C )                                                           *\n",
    " *  \\  \\  \\_____/   This library is distributed in hope that it will be      *\n",
    " *   \\  \\__         useful but WITHOUT ANY WARRANTY; without even the        *\n",
    " *    \\___          implied warranty of MERCHANTABILITY or FITNESS FOR A     *\n",
    " *                  PARTICULAR PURPOSE. See the GNU LGPL for more details.   *\n",
    " *                                                                           *\n",
    " * You should have received a copy of the GNU Lesser General Public License  *\n",
    " * along with this library; if not, write to the Free Software Foundation,   *\n",
    " * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA          *\n",
    " * The GNU LEsser General Public License is contained in the file COPYING.   *\n",
    " *                                 ---------                                 *\n",
    " *   Barcelona Supercomputing Center - Centro Nacional de Supercomputacion   *\n",
    "\\*****************************************************************************/\n",
    "\n",
    "/* -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- *\\\n",
    " | @file: $File$\n",
    " | @last_commit: $Date$\n",
    " | @version: $Revision$\n",
    "\\* -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- */\n\n",
    "#include \"common.h\"\n\n",
  )
  for line in banner:
    fd.write( line )
  if rcsid:
    fd.write( "static char UNUSED rcsid[] = \"$Id$\";\n\n" )

def parse_type_name(declaration):
  """Split a C declaration into its type and its identifier.

  The identifier is whatever follows the right-most space, '*' or '&' of the
  declaration; everything up to and including that character is the type.

  declaration -- text such as "const char *buf", "int &ref" or "int foo".

  Returns a (type, name) tuple of stripped strings. When none of the three
  separators is present, the type is empty and the name is the whole text.
  """
  # Surrounding blanks would otherwise be taken for the type/name separator
  # and leave the name empty (e.g. "int a " coming from "int a , int b").
  declaration = declaration.strip()

  # rfind() yields -1 for a missing separator, so the maximum is the index of
  # the right-most separator, or -1 when there is none at all.
  # The reference marker looked for is '&'; the previous test for '$' meant
  # references were never detected.
  last_sep = max( declaration.rfind(' '),
                  declaration.rfind('*'),
                  declaration.rfind('&') )

  type_part = declaration[:last_sep+1].strip()
  name_part = declaration[last_sep+1:].strip()

  return (type_part, name_part)

def parse_parameters(declaration):
  """Parse a comma-separated C parameter list into (type, name) pairs.

  declaration -- the text found between the parentheses of a prototype,
                 e.g. "int fd, const char *buf".

  Returns a list with one (type, name) tuple per parameter, in order. An
  empty list, or one that only says "void", yields an empty result.
  """
  parsed_args = []
  for current_arg in declaration.split(','):
    current_arg = current_arg.strip()
    # "" comes from an empty "()" list and "void" from "(void)": neither is
    # a real argument. Blanks are stripped first so " void " is recognised.
    if not current_arg or current_arg == "void":
      continue
    parsed_args.append( parse_type_name(current_arg) )

  return parsed_args

def get_format_for_type(type):
  """Return the printf conversion used to print a value of the given C type.

  type -- C type text as produced by parse_type_name().

  Any type containing '*' or '&' is printed as a pointer ("%p"); other types
  are looked up in the global 'Format' table. An unknown type is reported on
  standard output and the script terminates with a non-zero exit status.
  """
  if "*" in type or "&" in type:
    return "%p"

  if type in Format:
    return Format[type]

  # Single-argument print() calls behave the same on Python 2 and Python 3
  print("ERROR: Unknown type '"+type+"'!")
  print("Edit the script 'wrapgen' and add a new entry to the hash 'Format' to translate this type, and run the script again")
  # This is a failure: report it to the caller with a non-zero status
  # (it used to exit with 0, i.e. success).
  sys.exit(1)

######################################################################
#                               MAIN                                 #
######################################################################

# Parse arguments: wrapgen <API> <INTERFACES-FILE> <EVENT_BASE_NUMBER>
argc = len(sys.argv)

if (argc != 4):
  Usage()
  sys.exit(-1)

API        = sys.argv[1]
Interfaces = sys.argv[2]
# NOTE(review): EvBaseNo is parsed here but not referenced anywhere else in
# this script -- confirm whether the event base was meant to be emitted.
try:
  EvBaseNo = int(sys.argv[3])
except ValueError:
  # A non-numeric event base is an invalid argument: show the usage text
  # instead of letting the script die with a traceback
  Usage()
  sys.exit(-1)

# Every generated file is named after the lower-cased API name
fd_wrappers_c = open(API.lower()+'_wrappers.c', 'w')
fd_wrappers_h = open(API.lower()+'_wrappers.h', 'w')
fd_probes_c   = open(API.lower()+'_probes.c', 'w')
fd_probes_h   = open(API.lower()+'_probes.h', 'w')
fd_events_c   = open(API.lower()+'_events.c', 'w')
fd_events_h   = open(API.lower()+'_events.h', 'w')

######################################################################
#                        GENERATE WRAPPERS                           #
######################################################################

# Read the whole list of prototypes up front. The context manager closes the
# interfaces file, which used to stay open for the rest of the run.
with open( Interfaces, "r" ) as fd_interfaces:
  all_functions = fd_interfaces.readlines()

fd_wrappers_c.write("#define _GNU_SOURCE\n\n")

write_header( fd_wrappers_c, True )

fd_wrappers_c.write("#include <dlfcn.h>\n")
fd_wrappers_c.write("#include <stdio.h>\n")
fd_wrappers_c.write("#include <stdlib.h>\n")
fd_wrappers_c.write("#include \"wrapper.h\"\n")
fd_wrappers_c.write("#include \""+API.lower()+"_wrappers.h\"\n")
fd_wrappers_c.write("#include \""+API.lower()+"_probes.h\"\n")

fd_wrappers_c.write("\n")

fd_wrappers_c.write("/****************************************\\\n")
fd_wrappers_c.write(" ***     POINTERS TO REAL SYMBOLS     ***\n")
fd_wrappers_c.write("\\****************************************/\n\n")

# Parse every prototype, record it in 'Functions' under consecutive integer
# ids and declare the pointer that will hold the address of the real symbol
ID = 0
for declaration in all_functions:
  declaration = declaration.strip()
  # Skip blank lines and comment lines. Stripping first means that an
  # indented '#' is recognised as a comment too.
  if not declaration or declaration[0] == "#":
    continue
  if '(' not in declaration:
    # Without a parameter list there is nothing to split on; report the
    # offending line instead of failing with an IndexError
    print("ERROR: Malformed declaration '"+declaration+"' (expected '<rettype> <name> (<arguments>)')!")
    sys.exit(1)

  tokens           = declaration.split('(')
  rettype_and_name = tokens[0].strip()
  parameters       = tokens[1].split(')')[0].strip()
  (rettype, name)  = parse_type_name( rettype_and_name )
  parameters_array = parse_parameters( parameters )

  Functions[ID] = {
    NAME:        name,
    RETTYPE:     rettype,
    ARGS_ARRAY:  parameters_array,
    ARGS_STRING: parameters,
  }

  fd_wrappers_c.write( "static " + rettype + " (*" + name + "_real) (" + parameters + ") = NULL;\n" )
  ID = ID + 1
fd_wrappers_c.write("\n")

# Section banner for the routine that resolves the real symbols
fd_wrappers_c.writelines([
  "/****************************************\\\n",
  " ***          SYMBOLS HOOK-UP         ***\n",
  "\\****************************************/\n\n",
])

# Opening of Get_<API>_Hook_Points(): symbols are looked up with dlsym()
# starting from RTLD_NEXT, and only when compiled as PIC
fd_wrappers_c.writelines([
  "static void Get_" + API + "_Hook_Points (int rank)\n{\n",
  "#if defined(PIC)\n",
  "# if defined(__APPLE__)\n",
  "# error \"Search for the appropriate library to load the symbols dynamically\"\n",
  "# else\n",
  "  void *lib = RTLD_NEXT;\n",
  "# endif /* __APPLE__ */\n\n",
])

# One dlsym() lookup per function, cast to the function's own pointer type
for idx in range(len(Functions)):
  entry  = Functions[idx]
  symbol = entry[NAME]
  result = entry[RETTYPE]

  # The cast only needs the argument types; an empty list is spelled "void"
  cast_args = ', '.join([str(arg_type) for (arg_type, arg_name) in entry[ARGS_ARRAY]])
  if not cast_args:
    cast_args = "void"

  fd_wrappers_c.writelines([
    "  /* Obtain @ for " + symbol + " */\n",
    "  " + symbol + "_real = \n",
    "    (" + result + " (*)(" + cast_args + "))\n",
    "    dlsym( lib, \"" + symbol + "\" );\n",
    "  if (" + symbol + "_real == NULL && rank == 0)\n",
    "    fprintf(stderr, PACKAGE_NAME\": Unable to find " + symbol + " in DSOs!!\\n\");\n\n",
  ])

fd_wrappers_c.writelines([
  "#endif /* PIC */\n",
  "}\n\n",
])

# Section banner for the wrappers that replace the original symbols
fd_wrappers_c.writelines([
  "/****************************************\\\n",
  " ***           INJECTED CODE          ***\n",
  "\\****************************************/\n\n",
])

# The header declaring the wrappers starts with the banner and include guard
write_header( fd_wrappers_h, False )
wrappers_guard = "__" + API.upper() + "__WRAPPERS_H__"
fd_wrappers_h.write( "#ifndef " + wrappers_guard + "\n" )
fd_wrappers_h.write( "#define " + wrappers_guard + "\n\n" )

for idx in range(len(Functions)):
  entry     = Functions[idx]
  symbol    = entry[NAME]
  result    = entry[RETTYPE]
  signature = entry[ARGS_STRING]

  # Argument names are forwarded to the probe and to the real function; the
  # printf conversions are only used by the DEBUG trace. Resolving the
  # conversions comes first because an unknown type terminates the script.
  call_names = []
  debug_fmts = []
  for (arg_type, arg_name) in entry[ARGS_ARRAY]:
    call_names.append( str(arg_name) )
    debug_fmts.append( str(get_format_for_type( arg_type )) )
  call_args  = ', '.join(call_names)
  debug_fmts = ' '.join(debug_fmts)

  # Functions that return something keep the result in a local 'res'
  returns_value = (result != "void")
  real_call = symbol + "_real(" + call_args + ");\n"
  if returns_value:
    real_call = "res = " + real_call

  # Prototype in the header
  fd_wrappers_h.write( result + " " + symbol + " (" + signature + ");\n" )

  # Definition in the C file
  code = [ result + " " + symbol + " (" + signature + ")\n{\n" ]
  if returns_value:
    code.append( "  " + result + " res = 0;\n\n" )

  code.append( "#if defined(DEBUG)\n" )
  if debug_fmts:
    code.append( "  fprintf(stderr, PACKAGE_NAME\": DEBUG: " + debug_fmts + "\\n\", " + call_args + ");\n" )
  code.append( "  fprintf(stderr, PACKAGE_NAME\": DEBUG: " + symbol + "_real at %p\\n\", " + symbol + "_real);\n" )
  code.append( "#endif\n\n" )

  # Instrumented path: real call surrounded by the entry/exit probes
  code += [
    "  if (" + symbol + "_real != NULL && EXTRAE_ON())\n",
    "  {\n",
    "    Backend_Enter_Instrumentation(2);\n",
    "    PROBE_" + symbol + "_ENTRY(" + call_args + ");\n",
    "    " + real_call,
    "    PROBE_" + symbol + "_EXIT();\n",
    "    Backend_Leave_Instrumentation();\n",
    "  }\n",
  ]
  # Tracing disabled: plain forward to the real function
  code += [
    "  else if (" + symbol + "_real != NULL && !EXTRAE_ON())\n",
    "  {\n",
    "    " + real_call,
    "  }\n",
  ]
  # The real symbol could not be resolved
  code += [
    "  else\n",
    "  {\n",
    "    fprintf(stderr, PACKAGE_NAME\": Error " + symbol + " was not hooked!\\n\");\n",
    "    exit(-1);\n",
    "  }\n",
  ]
  if returns_value:
    code.append( "  return res;\n" )
  code += [ "}\n", "\n" ]

  fd_wrappers_c.writelines( code )

# Section banner for the constructor that resolves the symbols at load time
fd_wrappers_c.writelines([
  "/****************************************\\\n",
  " ***       MODULE INITIALIZATION      ***\n",
  "\\****************************************/\n\n",
])

fd_wrappers_c.writelines([
  "void __attribute__ ((constructor)) Extrae_" + API + "_init(void)\n",
  "{\n",
  "  Get_" + API + "_Hook_Points(0);\n",
  "}\n",
  "\n",
])

# Declare the initializer and close the include guard
fd_wrappers_h.write( "\nvoid Extrae_" + API + "_init(void);\n" )
fd_wrappers_h.write( "\n#endif /* " + wrappers_guard + " */\n" )

fd_wrappers_c.close()
fd_wrappers_h.close()

######################################################################
#                         GENERATE PROBES                            #
######################################################################

write_header( fd_probes_c, True )
write_header( fd_probes_h, False )

fd_probes_c.writelines([
  "#include \"" + API.lower() + "_events.h\"\n",
  "#include \"" + API.lower() + "_probes.h\"\n",
  "#include \"wrapper.h\"\n\n",
])

probes_guard = "__" + API.upper() + "__PROBES_H__"
fd_probes_h.write( "#ifndef " + probes_guard + "\n" )
fd_probes_h.write( "#define " + probes_guard + "\n\n" )

# Every function gets an ENTRY probe (taking the function's own arguments)
# and an EXIT probe (taking none); each one emits a single event
for idx in range(len(Functions)):
  symbol    = Functions[idx][NAME]
  signature = Functions[idx][ARGS_STRING]
  event     = symbol.upper() + "_EV"

  probes = (
    ("ENTRY", signature, "LAST_READ_TIME", "EVT_BEGIN"),
    ("EXIT",  "void",    "TIME",           "EVT_END"),
  )
  for (phase, params, clock, value) in probes:
    prototype = "void PROBE_" + symbol + "_" + phase + " (" + params + ")"

    fd_probes_h.write( prototype + ";\n" )
    fd_probes_c.writelines([
      prototype + "\n",
      "{\n",
      "  DEBUG_PROBES();\n",
      "  if (EXTRAE_ON())\n",
      "  {\n",
      "    TRACE_EVENTANDCOUNTERS(" + clock + ", " + event + ", " + value + ", EMPTY);\n",
      "  }\n",
      "}\n\n",
    ])

fd_probes_h.write( "\n#endif /* " + probes_guard + " */\n" )

fd_probes_c.close()
fd_probes_h.close()

######################################################################
#                         GENERATE EVENTS                            #
######################################################################

write_header( fd_events_h, False )

events_guard   = "__" + API.upper() + "_EVENTS_H__"
function_names = [ Functions[idx][NAME] for idx in range(len(Functions)) ]

# Header: include guard, placeholder for the event type (to be filled in by
# hand) and one enumerator per function
fd_events_h.write( "#ifndef " + events_guard + "\n" )
fd_events_h.write( "#define " + events_guard + "\n" )
fd_events_h.write( "\n" )
fd_events_h.write( "#include \"events.h\"\n" )
fd_events_h.write( "\n" )
fd_events_h.write( "#define "+API.upper()+"_EVENT_TYPE <PUT-THE-EVENT-"+API.upper()+"-TYPE-HERE>\n" )
fd_events_h.write( "\n" )
fd_events_h.write( "typedef enum {\n" )
if function_names:
  # Entries are separated by commas; the last one carries none
  fd_events_h.write( ",\n".join([ "  " + name.upper() + "_EV" for name in function_names ]) + "\n" )
fd_events_h.write( "} "+API.lower()+"_event_t;\n" )
fd_events_h.write( "\n" )
fd_events_h.write( "#endif /* " + events_guard + " */\n" )

# Source: table with one label (the function name) per event
write_header( fd_events_c, True )
fd_events_c.write( "#include \""+API.lower()+"_events.h\"\n" )
fd_events_c.write( "\n" )
fd_events_c.write( "char *"+API.lower()+"_events_labels["+str(len(function_names))+"] = {\n" )
if function_names:
  fd_events_c.write( ",\n".join([ "  \"" + name + "\"" for name in function_names ]) + "\n" )
# Terminate the last line: a C source file should end with a newline
fd_events_c.write( "};\n" )

fd_events_c.close()
# The events header used to be left open
fd_events_h.close()
