#!/usr/bin/python
# -*- coding: utf-8 -*- 

    # Copyright (C) 2010–2015 Agnieszka Patejuk

    # This program is free software: you can redistribute it and/or modify
    # it under the terms of the GNU General Public License as published by
    # the Free Software Foundation, either version 3 of the License, or
    # (at your option) any later version.

    # This program is distributed in the hope that it will be useful,
    # but WITHOUT ANY WARRANTY; without even the implied warranty of
    # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    # GNU General Public License for more details.

    # You should have received a copy of the GNU General Public License
    # along with this program.  If not, see <http://www.gnu.org/licenses/>.

import sys
import os
from common2xle import *

# Command-line arguments (all three are required; a missing one raises IndexError):
#   nkjp      - root directory that is walked for ann_morphosyntax.xml files
#   testsuite - prefix for the log files opened below; also passed to extraInfo()
#   outpath   - string prefix of every per-sentence output file; it is concatenated
#               directly with the file name, so a directory needs its trailing separator
nkjp = sys.argv[1]
testsuite = sys.argv[2]
outpath = sys.argv[3]
# Log of corpus file paths for which a sentence produced no tokens.
badinterp = open(testsuite+'-badinterp', 'w')
# Log of empty base/ctag values met while parsing (written by getNKJPmorfsyn).
error = open(testsuite+'-xml-error', 'w')

def findPathsPattern(rootpath, pattern):
    """Collect the paths of all files below rootpath whose name matches pattern.

    pattern is a regular expression applied with re.match, i.e. it is
    anchored at the start of the bare file name only.  Paths are returned
    in os.walk order, each joined with the directory it was found in.
    """
    return [
        os.path.join(dirpath, filename)
        for dirpath, _subdirs, filenames in os.walk(rootpath)
        for filename in filenames
        if re.match(pattern, filename)
    ]

# All morphosyntactic annotation files found below the corpus root.
# NOTE(review): the second argument is used as a regular expression with re.match,
# so "." matches any character and only the start of the file name is anchored -
# a name such as "ann_morphosyntax.xml~" would match as well; confirm this is intended.
nkjp_paths = findPathsPattern(nkjp, "ann_morphosyntax.xml")

def gstr(line):
    """Return the text enclosed by a one-line <string>...</string> element.

    The tags themselves are not checked: the function simply drops as many
    leading and trailing characters as the two tags are long.
    """
    opening, closing = '<string>', '</string>'
    return line[len(opening):-len(closing)]

def gctag(line):
    """Return the value of a one-line <symbol value="..."/> element.

    The markup is not validated: the function drops as many leading and
    trailing characters as the fixed parts around the value are long.
    """
    prefix, suffix = '<symbol value="', '"/>'
    return line[len(prefix):-len(suffix)]

def gmsd(line):
    """Return [value, xml:id] taken from a one-line <symbol .../> element.

    The line is stripped, its leading "<" and trailing "/>" are removed and
    the rest is split on whitespace; the attributes starting with "value"
    and "xml:id" are unquoted by fixed-width slicing.  If one of the two
    attributes is absent the function fails with UnboundLocalError.
    """
    attributes = line.strip()[1:-2].split()
    for attribute in attributes:
        if attribute.startswith('xml:id'):
            # drop 'xml:id="' in front and the closing quote
            msdid = attribute[8:-1]
        elif attribute.startswith('value'):
            # drop 'value="' in front and the closing quote
            msdtag = attribute[7:-1]
    return [msdtag, msdid]

def gdsmb(line):
    """Return the identifier referenced by the second token of a disamb line.

    For a line such as <f fVal="#morph_X-msd" name="choice"/> the second
    whitespace-separated token is taken and its first seven characters
    (fVal="#) and its closing quote are removed.
    """
    fval_attribute = line.split()[1]
    return fval_attribute[len('fVal="#'):-1]

def getNKJPmorfsyn(path):
    """Extract the disambiguated tokens of every sentence in one corpus file.

    path -- path of an ann_morphosyntax.xml file; it is read line by line,
            each line being stripped and decoded from UTF-8.

    Returns a list with one (par, sent, tokens) tuple per closed sentence
    (a "</s>" line), where par and sent are cut out of the last attribute of
    the opening "<p " / "<s " line and tokens is a list of
    [orth, base, tag] lists.  tag is the ctag alone, or ctag and msd joined
    with ":" when the chosen interpretation has a non-empty msd.

    Empty base or ctag values are reported to the module-level `error` log.
    """
    # The context manager closes the corpus file even when parsing raises;
    # previously the handle was never closed explicitly.
    with open(path, 'r') as infile:
        return _parseNKJPmorfsyn(infile, path)


def _parseNKJPmorfsyn(lines, path):
    """Run the line-based state machine of getNKJPmorfsyn over lines.

    lines -- iterable of byte-string lines of the annotation file
    path  -- only used in the messages written to the `error` log

    The parser relies on every XML element it cares about standing on a
    line of its own; the inside_* flags record which elements are open.
    """
    inside_par = False
    inside_sent = False
    inside_segm = False
    inside_orth = False
    inside_interps = False
    inside_disamb = False
    all_interps = []
    for line in lines:
        line = line.strip().decode("utf-8")
        if line[:3] == '<p ':
            # identifier cut out of the last attribute of the tag
            # (presumably xml:id="..."> - TODO confirm against the corpus)
            par = line.split()[-1][8:-2]
            inside_par = True
        if inside_par:
            if line[:3] == '<s ':
                sent = line.split()[-1][8:-2]
                inside_sent = True
                tokens = []
                # reset all other sentence-level values
                inside_segm = False
                inside_orth = False
                inside_interps = False
                inside_disamb = False
            if inside_sent:
                if line[:5] == '<seg ':
                    inside_segm = True
                if inside_segm:
                    if line == '</seg>':
                        inside_segm = False
                    if line == '<f name="orth">':
                        inside_orth = True
                        # a new token starts: [orth, base, tag] is built here
                        localinterp = []
                    if inside_orth:
                        if re.match('<string>.*</string>', line):
                            orth = gstr(line)
                            localinterp.append(orth)
                        if line == '</f>':
                            inside_orth = False
                    if line == '<f name="interps">':
                        inside_interps = True
                        # maps msd xml:id -> [base, ctag, msd] for this token
                        idct = {}
                        inside_sublex = False
                    if inside_interps:
                        if re.search('<fs .*type="lex"', line):
                            inside_sublex = True
                            inside_sublex_base = False
                            inside_sublex_ctag = False
                            inside_sublex_msd = False
                        if inside_sublex:
                            if line == '<f name="base">':
                                inside_sublex_base = True
                                # to avoid crashes when xml is malformed
                                lxbase = ""
                            if inside_sublex_base:
                                if re.match('<string>.*</string>', line):
                                    lxbase = gstr(line)
                                if line == '</f>':
                                    inside_sublex_base = False
                                    if lxbase == '':
                                        # the logged orth is replaced by a
                                        # placeholder; the token itself keeps
                                        # the orth appended earlier
                                        orth = 'some_problem_with_encoding'
                                        error.write('base error: '+path+'\t'+par+'\t'+sent+'\t'+orth.encode("utf-8")+'\n')
                            if line == '<f name="ctag">':
                                inside_sublex_ctag = True
                                # to avoid crashes when xml is malformed
                                lxctag = ""
                            if inside_sublex_ctag:
                                if re.match('<symbol value=".*"/>', line):
                                    lxctag = gctag(line)
                                if line == '</f>':
                                    inside_sublex_ctag = False
                                    if lxctag == '':
                                        error.write('ctag error: '+path+'\t'+par+'\t'+sent+'\t'+orth.encode("utf-8")+'\n')
                            if line == '<f name="msd">':
                                inside_sublex_msd = True
                            if inside_sublex_msd:
                                if re.match('<symbol .+-msd.*"/>', line):
                                    lxmsd = gmsd(line)
                                    msdid = lxmsd[1]
                                    # msdtag may be empty (no tags apart from ctag)
                                    msdtag = lxmsd[0]
                                    idct[msdid] = [lxbase, lxctag, msdtag]
                                if line == '</f>':
                                    inside_sublex_msd = False
                    if line == '<f name="disamb">':
                        inside_disamb = True
                    if inside_disamb:
                        if re.match('<f fVal="#morph_.+-msd" name="choice"/>', line):
                            disamb = gdsmb(line)
                            # NOTE(review): raises KeyError when the chosen id
                            # was not collected among the interpretations
                            chosen = idct[disamb]
                            chbase = chosen[0]
                            chtag = chosen[1]
                            if chosen[2] != "":
                                chtag = ":".join(chosen[1:])
                            localinterp.append(chbase)
                            localinterp.append(chtag)
                            # keep only complete [orth, base, tag] triples
                            if len(localinterp) == 3:
                                tokens.append(localinterp)
                            inside_disamb = False
                if line == '</s>':
                    all_interps.append((par, sent, tokens))
    return all_interps

def glueSent(interps):
    """Return the sentence text for a list of tokens, encoded as UTF-8.

    The first element of every item in interps is taken and the elements
    are joined with single spaces.
    """
    sentence = ' '.join([token[0] for token in interps])
    return sentence.encode("utf-8")

def wyklucz_powtorki(terminals, dct):
    """Add terminals to dct, grouped by token and without repetitions.

    terminals -- iterable of sequences whose first element is the token
    dct       -- dictionary updated in place: token -> list of the distinct
                 terminals seen for that token, in order of first occurrence

    Nothing is returned.  dict.setdefault replaces the dict.has_key test,
    which no longer exists in Python 3; the unused lookups of the second
    and third element of every terminal were dropped.
    """
    for terminal in terminals:
        known = dct.setdefault(terminal[0], [])
        if terminal not in known:
            known.append(terminal)

def sortPaths(paths):
    """Group paths by the answer that getAnswer() reads from each file.

    Returns a dictionary mapping every answer to the list of distinct paths
    that produced it, in order of first occurrence.  dict.setdefault
    replaces the dict.has_key test, which no longer exists in Python 3, and
    the body is indented consistently with spaces.
    """
    dct = {}
    for path in paths:
        same_answer = dct.setdefault(getAnswer(path), [])
        if path not in same_answer:
            same_answer.append(path)
    return dct

def getAnswer(path):
    """Return the answer type recorded in the file at path, or None.

    The first line matching '<base-answer type=".+" username=.*>' decides:
    its second whitespace-separated token is taken and the first six
    characters (type=") and the closing quote are removed.  None is
    returned when no line matches.

    The body is indented with spaces only (it used to mix an 8-space line
    with a tab-indented one, which Python 3 rejects), and the file is closed
    by the context manager.
    """
    with open(path) as infile:
        for line in infile:
            line = line.strip()
            if re.search('<base-answer type=".+" username=.*>', line):
                return line.split()[1][6:-1]
    return None

# Convert every annotation file: each sentence gets a "...__TEST" file with
# its space-joined text and, when it has tokens, a "...__DICT" file written
# by zapisz_slownik() from the sentence's distinct terminals.
# (Iterating an empty nkjp_paths is a no-op, so no separate emptiness test.)
for path in nkjp_paths:
    all_path_data = getNKJPmorfsyn(path)
    # name of the directory that directly contains the annotation file
    nkjp_subdir = path.split('/')[-2]
    for (par_id, sent_id, interps) in all_path_data:
        # common part of both output file names for this sentence
        prefix = outpath+'NKJP1M'+'_'+nkjp_subdir+'_'+par_id+'_'+sent_id+'__'
        # The TEST file is created even for a sentence without tokens (it
        # then stays empty); the context manager closes it, which the
        # original never did explicitly.
        with open(prefix+'TEST', 'w') as segmfile:
            if interps:
                segmfile.write(glueSent(interps)+'\n\n')
                localdct = {}
                wyklucz_powtorki(interps, localdct)
                zapisz_slownik(prefix+'DICT', localdct)
            else:
                # no usable tokens: log the corpus file the sentence is from
                badinterp.write(path+'\n\n')

# punct, oov and extraInfo are not defined in this file (presumably provided
# by the star import from common2xle - verify there).
extraInfo(punct, testsuite, "punct")
extraInfo(oov, testsuite, "oov")
