#!/usr/bin/env runhaskell
-- Baseline unigram tagger
-- Jon Dehdari, 2009-2010
-- Usage: ./unigram-tag  +RTS -N2 -RTS  [--default <tag>]  trainFile [testFile [outFile]]   (testFile defaults to stdin, outFile to stdout)
-- Compile: ghc --make -fforce-recomp -threaded unigram-tag.hs -o unigram-tag
-- License: GPLv3 (see http://www.gnu.org/licenses/gpl3.html)

import System.Environment (getArgs)
import Data.List (sort)
--import Data.Char (toLower,isDigit)
import qualified Data.Map as Map
import Control.Arrow ((&&&))
--import Control.Parallel.Strategies
--import Control.Parallel

---- Language Model is a list of (count, word, tag)
type Model        = [(Int, String, String)]
type ModelMap     = Map.Map String String

main :: IO ()
main = do
    let usage = ("\
\Usage: ./unigram-tag  [options]  trainfile  [testfile  [outfile]]\n\
\Unigram-tag: A small, fast unigram tagger\n\
\\n\
\Options:\n\
\    --default <tag>   Use <tag> for unseen words (default: find most-frequent tag for singletons)\n\
\ ")

    let columnSep = '\t'
    ------ For development
    -- let args = ["train.txt", "test.txt", "out.out"]
    args <- getArgs
    -- parseArgs either errors out or yields exactly five elements:
    -- ["--default", tag or "_NONE_", trainFile, testFile or "_STDIN_", outFile or "_STDOUT_"]
    let (_:defaultTagArg:trainFileName:testFileName:outFileName:_) = parseArgs usage args
    train <- readFile trainFileName
    test  <- if testFileName == "_STDIN_"
             then getContents
             else readFile testFileName

    -- One (count, word, tag) entry per distinct training line
    let model = splitColumn columnSep . amalgamate . initAmalgamate . sort . lines $ train
    -- Boil down model to only most-frequently-occurring tags for each word:
    -- after sorting by count, Map.fromList keeps the last (highest-count) tag for a word
    let modelMap = Map.fromList . map (\(_, word, tag) -> (word, tag)) . sort $ model
    -- Tags stored in the model keep their leading separator (splitColumn uses
    -- break), so a user-supplied default tag is given one as well
    let defaultTag = if defaultTagArg /= "_NONE_" then columnSep:defaultTagArg else mostFreqTag model
        output     = format . tagWords defaultTag modelMap . lines $ test

    if outFileName == "_STDOUT_"
        then putStr output
        else writeFile outFileName output

-------------------------------------------------------------------------------

-- A "parallel" map hook.  Currently sequential: the parallel version below
-- needs Control.Parallel.Strategies, whose import is commented out.
--pMap f xs = parBuffer 4 rwhnf $ map f xs
pMap :: (a -> b) -> [a] -> [b]
pMap = map

--strict :: (a -> b) -> a -> b
--strict f x = seq x (f x)


---- IO stuff
-- | Render tagged tokens, one per line, newline-terminated.  Each word is
-- concatenated directly with its tag; no separator is inserted here, so the
-- tag is expected to carry its own leading column separator.
format :: [(String, String)] -> String
format = unlines . map (uncurry (++))

-- | Normalise the command line into exactly five elements:
--
--   ["--default", tag, trainFile, testFile, outFile]
--
-- tag defaults to "_NONE_", testFile to "_STDIN_" and outFile to "_STDOUT_".
-- A single "--default <tag>" pair may appear anywhere among the arguments.
-- Calls 'error' with the usage text on -h/--help, on a missing or repeated
-- --default value, or on the wrong number of file arguments.
parseArgs :: String -> [String] -> [String]
parseArgs usage args
  | "--help" `elem` args || "-h" `elem` args = error usage
  | otherwise =
      case extractDefault args of
        Just (tag, positional) -> normalise tag positional
        Nothing                -> error usage
  where
    -- Remove the "--default <tag>" pair, returning the tag and the remaining
    -- positional arguments; Nothing if the flag has no value or is repeated.
    extractDefault xs =
      case break (== "--default") xs of
        (_, []) -> Just ("_NONE_", xs)
        (before, _ : tag : after)
          | "--default" `notElem` after -> Just (tag, before ++ after)
        _ -> Nothing
    -- Fill in the sentinels for the optional test and output files.
    normalise tag [train]            = ["--default", tag, train, "_STDIN_", "_STDOUT_"]
    normalise tag [train, test]      = ["--default", tag, train, test, "_STDOUT_"]
    normalise tag [train, test, out] = ["--default", tag, train, test, out]
    normalise _   _                  = error usage

---- Build Model
-- | Pair every line with an initial occurrence count of one, ready to be
-- summed by 'merge'.
initAmalgamate :: [String] -> [(String, Int)]
initAmalgamate = map (\line -> (line, 1))

--- Collapse runs of equal keys into (count, key) pairs.  The input must be
--- pre-sorted so that equal keys are adjacent.  An empty input means the
--- training set was empty and is reported with 'error'.
amalgamate :: [(String, Int)] -> [(Int, String)]
amalgamate []    = error "Empty training set (in amalgamate function)"
amalgamate pairs = [ (count, key) | (key, count) <- merge pairs ]

-- | Sum the counts of adjacent entries that share a key (input should be
-- sorted so that equal keys are adjacent).  Total on all inputs, including
-- the empty list.  The running total is forced at each step so that a
-- frequently repeated key does not build a long chain of (+) thunks.
merge :: [(String, Int)] -> [(String, Int)]
merge [] = []
merge [a] = [a]
merge ((a1,b1):(a2,b2):rest)
  | a1 == a2  = let total = b1 + b2
                in total `seq` merge ((a1, total) : rest)
  | otherwise = (a1, b1) : merge ((a2, b2) : rest)

--- foldr-able version of merge
--merge :: (Eq a, Num b) => (a, b) -> [(a, b)] -> [(a, b)]
--merge (a1,b1) [] = [(a1,b1)]
--merge (a1,b1) ((a2,b2):rest)
--  | a1==a2     = ((a1,b1+b2):rest)
--  | otherwise  = ((a1,b1):(a2,b2):rest)

-- | Split each counted training line at the first occurrence of @columnSep@
-- into (count, word, tag).  'break' leaves the separator at the front of the
-- tag part, so stored tags begin with @columnSep@ (when it is present).
splitColumn :: Char -> [(Int, String)] -> Model
splitColumn columnSep = map splitEntry
  where
    splitEntry (count, line) =
      case break (== columnSep) line of
        (word, tag) -> (count, word, tag)

-- | Choose the default tag for unseen words: the tag shared by the most
-- singleton (count == 1) word/tag entries.  If the model has no singleton
-- entries, fall back to the tag with the highest total count over the whole
-- model instead of failing.  Ties go to the lexicographically greatest tag.
-- The tag is returned exactly as stored in the model.
mostFreqTag :: Model -> String
mostFreqTag model =
    case bestTag singletons of
      Just tag -> tag
      Nothing  ->
        case bestTag model of
          Just tag -> tag
          Nothing  -> error "Empty training set (in mostFreqTag function)"
  where
    singletons = [ entry | entry@(count, _, _) <- model, count == 1 ]
    -- Sum the counts per tag and return the tag with the largest total, if any.
    bestTag entries =
      case Map.toList (Map.fromListWith (+) [ (tag, count) | (count, _, tag) <- entries ]) of
        []     -> Nothing
        totals -> Just (snd (maximum [ (total, tag) | (tag, total) <- totals ]))

---- Test words
-- | Pair a token with its tag from the model, or with @defaultTag@ when the
-- token does not occur in the model.  Uses a single map lookup.
--
-- Experiments kept for reference:
--   * Retrying with a lower-cased first letter
--     ((toLower (head w)) : (tail w)) actually lowers accuracy.
--   * Tagging tokens that start with a digit as "\tCD" raises accuracy,
--     but is only semi-language-portable.
tagWord :: String -> ModelMap -> String -> (String, String)
tagWord defaultTag model w = (w, Map.findWithDefault defaultTag w model)

-- | Tag every token (one per input line) with 'tagWord'.  An empty token
-- list is reported with 'error'.
tagWords :: String -> ModelMap -> [String] -> [(String, String)]
tagWords defaultTag model ws =
    case ws of
      [] -> error "No words to tag (in tagWords function)"
      _  -> pMap (tagWord defaultTag model) ws

