'''
Created on 03 Feb 2014

@author: Deon Marais (deon.marais@necsa.co.za)
@organization: The South African Nuclear Energy Corporation (Necsa) SOC Limited
@copyright: See copyright.txt

'''
from pyparsing import *
import os
import time
import itertools
import shutil
from copy import copy
import psutil
from subprocess import PIPE

    
#-------------------------------------------------------------------------------------------------------
def GetMaxLen(alist,idx):
    """Return the length of the longest ``str()`` of column ``idx`` across the rows of ``alist``.

    Rows that are too short to have column ``idx`` are skipped; when no row has
    that column the result is 0.
    """
    widths = [len(str(row[idx])) for row in alist if len(row) > idx]
    return max(widths) if widths else 0

#-------------------------------------------------------------------------------------------------------
# mincolsizes is used in conjunction with arr to determine the minimum column width
def WritePrettyArray(fhandle, arr, mincolsizes = [0]):
    """Write the rows of ``arr`` to ``fhandle`` as left-aligned, tab-terminated columns.

    Each column is padded with spaces to the widest ``str()`` of any cell in it,
    but never narrower than ``mincolsizes[i]`` when such an entry exists. Every
    cell is followed by a tab, every row by a newline, and one extra blank line
    is written after the table before the handle is flushed. ``mincolsizes`` is
    only read, never modified.
    """
    ncols = max([len(row) for row in arr] + [0])

    widths = []
    for col in range(ncols):
        # Widest rendering in this column among rows long enough to have it.
        width = max([len(str(row[col])) for row in arr if len(row) > col] + [0])
        if col < len(mincolsizes):
            width = max(width, mincolsizes[col])
        widths.append(width)

    for row in arr:
        cells = [str(cell).ljust(widths[col]) + "\t" for col, cell in enumerate(row)]
        fhandle.write("".join(cells) + "\n")
    fhandle.write("\n")
    fhandle.flush()

#-------------------------------------------------------------------------------------------------------
def GenPSBase(instrumentFile):
    """Takes a McStas instrument definition file name as input, parses the DEFINE INSTRUMENT section
    to generate the Parameter Study skeleton file.

    The text from the first "DEFINE INSTRUMENT" up to the first ")" after it is
    parsed for the instrument name and its parameter list (each parameter may
    carry an "= default"). The generated file contains a banner comment, any
    C-style / "//" comments that precede the DEFINE INSTRUMENT header, a
    [VARIABLE_PARAMETERS] table with one row per instrument parameter, and a
    [RUNTIME_PARAMETERS] table pre-filled with default run settings.

    Parameters:
        instrumentFile: path to the McStas .instr file.

    Returns:
        Absolute path of the written .parst file. It is "<instr>.parst", or
        "<instr>_<YYYYmmdd_HHMMSS>.parst" when the former already exists, so an
        existing study file is never overwritten.

    Raises:
        ValueError: if the file has no "DEFINE INSTRUMENT" header or the
            parameter list is not closed by ")".
    """
    instrumentFile = os.path.abspath(instrumentFile)
    with open(instrumentFile, 'r') as f:
        instFileContent = f.read()

    startIdx = instFileContent.find("DEFINE INSTRUMENT")
    if startIdx < 0:
        # Without this check find() returns -1 and the slices below silently
        # pick up unrelated text, producing a confusing parse error.
        raise ValueError("No 'DEFINE INSTRUMENT' section found in " + instrumentFile)
    # The parameter list is taken to end at the first ')' after the header, so a
    # ')' inside a default value or comment would truncate it.
    endIdx = instFileContent.find(")", startIdx)
    if endIdx < 0:
        raise ValueError("Unterminated DEFINE INSTRUMENT parameter list in " + instrumentFile)
    instFileDef = instFileContent[startIdx:endIdx + 1]

    # Tokens may contain any printable character except the list punctuation.
    validch = "".join([c for c in printables if c not in ",)=("])
    define_str = CaselessLiteral("DEFINE").suppress()
    instrument_str = CaselessLiteral("INSTRUMENT").suppress()
    instr_name = Word(validch)
    comment = ZeroOrMore(cStyleComment) + ZeroOrMore(dblSlashComment)
    lbrack, rbrack, comma, equal = map(Suppress, "(),=")
    # One parameter: name [= default] [,] with comments allowed between tokens.
    parameter = Suppress(comment) + Group(Word(validch) + Suppress(comment) +
        Optional(equal + Suppress(comment) + Word(validch) + Suppress(comment)) + ZeroOrMore(comma) + Suppress(comment))

    paramsexpr = define_str + instrument_str + instr_name + lbrack + ZeroOrMore(parameter) + rbrack
    paramsparsed = paramsexpr.parseString(instFileDef)
    # Header comments that precede the DEFINE INSTRUMENT line are copied into the output.
    commentsparsed = comment.parseString(instFileContent[0:startIdx])

    base = os.path.splitext(instrumentFile)[0]
    outfilename = base + ".parst"
    if os.path.exists(outfilename):
        outfilename = base + "_" + time.strftime("%Y%m%d_%H%M%S") + ".parst"

    # NOTE: the key "Intrument_exe" (sic) must stay spelled this way; GetRuntime looks it up by this name.
    expparams = [["McStas_perl_exe", 'C:/mcstas-2.0/bin/mcrun.pl'],
                 ["Multiprocess_hosts", 'C:/mcstas-2.0/workspace/hosts'],
                 ["Intrument_exe", base + ".exe"],
                 ["Number_nodes", 11],
                 ["Neutron_count", 10000],
                 ["Random_seed", ""],
                 ["Output_dir", ""],
                 ["Plot_format", "PGPLOT", "//Options are: PGPLOT, Gnuplot, Matlab, Scilab, HTML_VRML"],
                 ["Execution_mode", "simulate", "//Options are: simulate, trace"],
                 ["Clustering", "MPI", "//Options are: none, MPI, SSH"]]

    # NOTE(review): binary mode with str writes is Python 2 behaviour. Because no
    # newline translation happens, the banner lines end in "\r\n" while all other
    # lines end in "\n" (mixed line endings) -- confirm whether that is intended.
    with open(outfilename, "wb") as foutput:
        foutput.write("/**************************************************************************************************\r\n")
        foutput.write(" *          McStas Parameter Study Base file\n")
        foutput.write(" *          Generated from " + instrumentFile + "\n")
        foutput.write(" *          Time: " + time.strftime("%c") + "\n")
        foutput.write(" *          Instrument: " + paramsparsed[0] + "\n")
        foutput.write(" **************************************************************************************************/\r\n\r\n")
        for cmnt in commentsparsed:
            foutput.write(cmnt)

        foutput.write("\n\n")
        foutput.write("[VARIABLE_PARAMETERS]\n")
        WritePrettyArray(foutput,
                         [["// Variable", "StartValue", "EndValue", "NrSteps", "[EndValue...]"],
                          ["// --------", "----------", "--------", "-------", "-------------"]] +
                         paramsparsed[1:])

        foutput.write("\n")
        foutput.write("[RUNTIME_PARAMETERS]\n")
        WritePrettyArray(foutput, expparams)

        foutput.write("\n")
    return outfilename
    
#-------------------------------------------------------------------------------------------------------
def GetSections(parst_file_path):
    """Parse a .parst file into its bracketed sections.

    Each section starts with a "[NAME]" header (comments around it are
    skipped) and is followed by parameter lines. Every parameter line becomes a
    flat list "[key, value1, value2, ...]": the key is the first
    alphanumeric/"_"/"." word and the values are the rest of the line split on
    whitespace.

    Parameters:
        parst_file_path: path of the .parst file to read.

    Returns:
        The pyparsing result: one group per section, whose first item is the
        section name and whose remaining items are the parameter lists.
    """
    with open(parst_file_path, 'r') as f:
        parstFileContent = f.read()

    validch = Word(alphanums + "_" + ".")
    comment = ZeroOrMore(cStyleComment) + ZeroOrMore(dblSlashComment)
    lbrack = Suppress("[")
    # A parameter line is a key (not a section header) followed by the rest of the line.
    paramdef = Group(Group(~lbrack + validch) + restOfLine)

    # Parse action: strip the rest-of-line text (this also drops a trailing "\r"),
    # split it into words, and flatten [[key], [v1, v2, ...]] into [key, v1, v2, ...].
    def stripKey(tokens):
        tokens[0][1] = tokens[0][1].strip()
        tokens[0][1] = tokens[0][1].split()
        tokens[0] = list(itertools.chain(*tokens[0]))
    paramdef.setParseAction(stripKey)

    sechead = Suppress(Literal("[")) + OneOrMore(validch) + Suppress(Literal("]"))
    secexpr = ZeroOrMore(Group(Optional(Suppress(comment)) + sechead + Optional(Suppress(comment)) + ZeroOrMore(paramdef)))
    seclist = secexpr.parseString(parstFileContent)
    # Single-argument print: same output as the old Python 2 print statement and valid in Python 3.
    print("Parameters parsed:  " + str(seclist))
    return seclist

#-------------------------------------------------------------------------------------------------------
def GetStaticDynamic(seclist, section):
    """Split the [VARIABLE_PARAMETERS] section into static and swept parameters.

    Each row after the section name is interpreted by its length:

    * ``name value``               -> static; returned unchanged in ``varstatic``
    * ``name []``                  -> swept over every value listed in section ``[name]``
    * ``name start end``           -> swept over ``[str(float(start)), end]``
    * ``name start end n``         -> ``n`` evenly spaced values starting at start
                                      (as ``str(float)``), followed by ``end`` verbatim
    * ``name start end 0|- [v...]``-> explicit values ``[start, end, v...]`` (no interpolation)

    Rows with fewer than two fields are ignored.

    Parameters:
        seclist: parsed sections, as returned by GetSections.
        section: mapping of section name -> index into ``seclist``.

    Returns:
        ``(varstatic, vardynvals)``; every ``vardynvals`` row is ``[name, value, ...]``.

    Raises:
        ValueError: when explicit values follow NrSteps but NrSteps is not
            '0' or '-'. Such a row used to yield no values at all, which silently
            reduced the whole study to zero runs.
    """
    varstatic = []
    vardynvals = []

    for prm in seclist[section['VARIABLE_PARAMETERS']][1:]:
        lenprm = len(prm)
        if lenprm > 2:
            if lenprm > 3 and prm[3] in ('0', '-'):
                # NrSteps of 0 or '-' means the values are given explicitly, so
                # nothing is interpolated (this also avoids dividing by zero).
                values = [prm[1], prm[2]] + list(prm[4:])
            elif lenprm > 4:
                raise ValueError("Parameter '" + str(prm[0]) + "': explicit values after NrSteps "
                                 "require NrSteps to be '0' or '-'")
            else:
                # Interpolate: NrSteps defaults to 1 when omitted (start and end only).
                nrSteps = float(prm[3]) if lenprm == 4 else 1.0
                start = float(prm[1])
                incr = (float(prm[2]) - start) / nrSteps
                values = [str(start + num * incr) for num in range(int(nrSteps))]
                values.append(prm[2])        # end value is kept exactly as written
            vardynvals.append([prm[0]] + values)
        elif lenprm == 2:
            if prm[1] == '[]':
                # Values for this parameter are listed in a separate section named [parameter].
                vardynvals.append([prm[0]] + list(itertools.chain(*seclist[section[prm[0]]][1:])))
            else:
                varstatic.append(prm)

    # Single-argument prints: same output as the old Python 2 print statements and valid in Python 3.
    print("Static parameters:  " + str(varstatic))
    print("Dynamic parameters:  " + str(vardynvals))
    totruns = float(1.0)
    for prm in vardynvals:
        totruns *= float(len(prm) - 1)
    print("Total number of runs:  " + str(totruns))
    return varstatic, vardynvals

#-------------------------------------------------------------------------------------------------------
def GetRuntime(prms,parst_file=''):
    """Build the mcrun command-line prefix and choose the output base directory.

    Parameters:
        prms: rows of the [RUNTIME_PARAMETERS] section, each ``[key]`` or
            ``[key, value, ...]``. When a key repeats, the last row wins.
        parst_file: path of the .parst file; used to derive a timestamped
            directory name when ``Output_dir`` has no value.

    Returns:
        ``(runln, dirbase)``. ``runln`` begins with a space and contains, in
        order: the quoted mcrun script, ``--machines`` (unless hosts is 'none'),
        the quoted instrument executable, ``--mpi``/``--multi`` for MPI/SSH
        clustering, ``--ncount``, ``--seed`` (if given) and ``--format``.

    Raises:
        KeyError: if a required runtime key is missing.
    """
    rows = dict((row[0], row) for row in prms)

    def value(key):
        return rows[key][1]

    pieces = [' "' + value('McStas_perl_exe') + '"']

    hosts = value('Multiprocess_hosts')
    if hosts != 'none':
        pieces.append(' --machines="' + hosts + '"')

    pieces.append(' "' + value('Intrument_exe') + '"')

    clustering = value('Clustering')
    if clustering == 'MPI':
        pieces.append(' --mpi=' + value('Number_nodes'))
    elif clustering == 'SSH':
        pieces.append(' --multi=' + value('Number_nodes'))

    pieces.append(' --ncount=' + value('Neutron_count'))

    if len(rows['Random_seed']) > 1:
        pieces.append(' --seed=' + value('Random_seed'))

    # Both PGPLOT and Gnuplot select --format=PGPLOT; the other listed formats
    # (Matlab, Scilab, HTML_VRML) add no --format option at all.
    # NOTE(review): the Gnuplot mapping may be deliberate (shared McStas data format) -- confirm.
    if value('Plot_format') in ('PGPLOT', 'Gnuplot'):
        pieces.append(' --format=PGPLOT')

    runln = ''.join(pieces)

    outdir_row = rows['Output_dir']
    if len(outdir_row) > 1:
        dirbase = outdir_row[1]
    else:
        dirbase = os.path.splitext(parst_file)[0] + "_" + time.strftime("%Y%m%d_%H%M%S")

    return runln, dirbase

#-------------------------------------------------------------------------------------------------------
def calccombination(alist, ret=None):
    """Return every combination of parameter assignments as space-separated strings.

    Parameters:
        alist: list of rows ``[name, value1, value2, ...]``.
        ret: optional seed list of assignment strings. They are placed ahead of
            the last parameter's assignments before the earlier parameters are
            combined in (the behaviour of the former recursive accumulator).
            The list passed in is never modified.

    Returns:
        A new list. The first parameter varies slowest, e.g.
        ``calccombination([["a","1","2"],["b","x"]]) == ["a=1 b=x", "a=2 b=x"]``.
        An empty ``alist`` yields a copy of the seed (``[]`` by default).
    """
    # A fresh list every call: the old mutable default ``ret=[]`` was extended
    # in place, so results leaked between calls and into the caller's list.
    combos = [] if ret is None else list(ret)
    if not alist:
        return combos

    lastname = alist[-1][0]
    combos += [lastname + "=" + val for val in alist[-1][1:]]

    # Prepend earlier parameters from right to left so the first one ends up leftmost.
    for param in reversed(alist[:-1]):
        paramname = param[0]
        combos = [paramname + "=" + val + " " + tail for val in param[1:] for tail in combos]
    return combos



#-------------------------------------------------------------------------------------------------------
def _process_name(proc):
    """Return the executable name of a psutil process across psutil versions.

    psutil < 2.0 exposes ``name`` as an attribute; psutil >= 2.0 makes it a method.
    Comparing the bound method to a string (as before) is always False.
    """
    name = proc.name
    return name() if callable(name) else name


def RunPerlMPIWait(exeln):
    """Run ``perl.exe <exeln>`` and block until it and an mpiexec.exe process have finished.

    Parameters:
        exeln: the command-line tail passed to perl (mcrun script path and options).

    After the perl process exits, the first running process named
    ``mpiexec.exe`` (if any) is located and polled until it disappears.
    NOTE(review): this matches *any* mpiexec.exe on the machine, not
    necessarily one started by this call.
    """
    proce = psutil.Popen("perl.exe " + exeln, stdout=None, stdin=None, stderr=None)
    # Fixed start-up delay kept from the original implementation.
    time.sleep(10.0)
    # wait() also reaps the child; polling pid_exists() can spin forever on an
    # unreaped (zombie) child on POSIX systems.
    proce.wait()

    mpi_pid = None
    for proc in psutil.process_iter():
        try:
            if _process_name(proc) == "mpiexec.exe":
                mpi_pid = proc.pid
                break
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            # Processes may vanish or be inaccessible mid-scan; skip them instead
            # of abandoning the whole search.
            continue

    if mpi_pid is None:
        return
    while psutil.pid_exists(mpi_pid):
        time.sleep(1)
        
#-------------------------------------------------------------------------------------------------------
def _next_batch_number(outdirbase):
    """Return one more than the largest numeric '<batch>_...' subdirectory prefix in outdirbase (1 if none)."""
    highest = 0
    for name in os.listdir(outdirbase):
        if not os.path.isdir(os.path.join(outdirbase, name)):
            continue
        prefix = name.split('_')[0]
        if prefix.isdigit():  # ignore unrelated directories instead of crashing on int()
            highest = max(highest, int(prefix))
    return highest + 1


class _TextCollector(object):
    """Minimal write/flush sink used to capture WritePrettyArray output as a string."""

    def __init__(self):
        self.parts = []

    def write(self, text):
        self.parts.append(text)

    def flush(self):
        pass


def _write_table_rows(fhandle, rows, colsizes):
    """Write rows formatted by WritePrettyArray to fhandle, without its trailing blank line."""
    sink = _TextCollector()
    WritePrettyArray(sink, rows, colsizes)
    # WritePrettyArray ends the table with an extra "\n"; drop it so later rows
    # continue the same table. This replaces a seek(-2, 1) on a text file, which
    # only worked with Windows "\r\n" translation and is illegal in Python 3.
    fhandle.write("".join(sink.parts)[:-1])
    fhandle.flush()


def GenPSCmds(parst_file):
    """Expand a .parst file into mcrun command lines, save them, and run them one by one.

    Steps:
      1. Parse the file (GetSections), split static/swept parameters
         (GetStaticDynamic) and build the mcrun prefix and output directory (GetRuntime).
      2. Create the output directory if needed and copy the .parst file, plus the
         matching .instr/.exe files if they exist, into it.
      3. Pick a batch number one higher than any existing '<n>_...' run directory,
         so earlier runs in the same output directory are not reused.
      4. Append one command per parameter combination to
         '<parst name>_<batch>.cmds'; run i writes to '<batch>_<i>'.
      5. Run each command with RunPerlMPIWait, logging run number, start/end time
         and the swept values to '<parst name>.log', until done or cancelled.

    Parameters:
        parst_file: path of the .parst parameter-study file.

    Returns:
        The (closed) file object of the .cmds file.
    """
    from mylib import ProgressBar

    parst_file = os.path.abspath(parst_file)
    basepath = os.path.splitext(parst_file)[0]
    instr_file = basepath + ".instr"
    exe_file = basepath + ".exe"
    seclist = GetSections(parst_file)
    section = {sec[0]: idx for idx, sec in enumerate(seclist)}   # section name -> index
    staticvals, dynvals = GetStaticDynamic(seclist, section)
    runln, outdirbase = GetRuntime(seclist[section['RUNTIME_PARAMETERS']][1:], parst_file)

    if not os.path.exists(outdirbase):
        os.mkdir(outdirbase)
    outdirbase = os.path.abspath(outdirbase).replace("\\", "/")
    shutil.copy(parst_file, outdirbase)
    for companion in (instr_file, exe_file):
        if os.path.exists(companion):
            shutil.copy(companion, outdirbase)

    statln = "".join(" " + prm[0] + "=" + prm[1] for prm in staticvals)
    curdirnr = _next_batch_number(outdirbase)
    combinations = calccombination(dynvals, [])

    # Build the commands in memory and run exactly these; re-reading the
    # append-mode .cmds file could pick up stale lines from an earlier batch.
    runs = []
    for i, dyn in enumerate(combinations, 1):
        runnr = str(curdirnr) + '_' + str(i)
        exeln = runln + ' --dir="' + outdirbase + "/" + runnr + '" ' + statln + ' ' + dyn
        runs.append((runnr, dyn, exeln))

    cmdsfilename = outdirbase + "/" + os.path.basename(parst_file) + "_" + str(curdirnr) + ".cmds"
    with open(cmdsfilename, 'a') as cmdsfile:
        for _, _, exeln in runs:
            cmdsfile.write(exeln + "\n")

    headers = ["Run_number", "       Start        ", "         End        "] + [x[0] for x in dynvals]
    underscores = ['-' * len(x) for x in headers]
    colsizes = [len(x) for x in headers]

    with open(outdirbase + "/" + os.path.basename(parst_file) + ".log", 'w') as logfile:
        _write_table_rows(logfile, [headers, underscores], colsizes)

        progress = ProgressBar("Running...", len(runs))
        for runnr, combination, exeln in runs:
            starttime = time.strftime("%Y/%m/%d %H:%M:%S", time.localtime())
            progress.setinfo(combination)
            dynvalues = [pair.split("=")[1] for pair in combination.split(" ")]
            RunPerlMPIWait(exeln)
            endtime = time.strftime("%Y/%m/%d %H:%M:%S", time.localtime())
            _write_table_rows(logfile, [[runnr, starttime, endtime] + dynvalues], colsizes)
            if progress.wasCanceled():
                break
            progress.step()

    return cmdsfile


#-------------------------------------------------------------------------------------------------------
def RunPSCmds(cmds_file):
    """Execute the command lines of a .cmds file (as written by GenPSCmds) one at a time.

    Each non-blank line, without its line terminator, is handed to
    RunPerlMPIWait. A mylib ProgressBar tracks progress; cancelling it stops
    before the next command starts.

    Parameters:
        cmds_file: path of the .cmds file.
    """
    from mylib import ProgressBar

    with open(cmds_file, 'r') as f:
        # Skip blank lines: they would start perl with an empty mcrun command.
        cmds = [line.rstrip("\r\n") for line in f if line.strip()]

    progress = ProgressBar("Running...", len(cmds))
    for exeln in cmds:
        RunPerlMPIWait(exeln)
        if progress.wasCanceled():
            break
        progress.step()

 


#-------------------------------------------------------------------------------------------------------
if __name__ == '__main__':
    import sys
    # Usage: python <this file> [file.instr | file.parst]
    # A .instr argument generates a parameter-study skeleton (GenPSBase) and
    # prints its path; any other argument is expanded and run as a .parst file
    # (GenPSCmds). Without an argument the developer's original study file is used.
    if len(sys.argv) > 1:
        target = sys.argv[1]
        if os.path.splitext(target)[1].lower() == ".instr":
            print(GenPSBase(target))
        else:
            GenPSCmds(target)
    else:
        GenPSCmds("D:/home/deon/Develop/McStas_Workspace/SAMTEST.parst")