#!/usr/bin/env python3

import argparse
from collections import defaultdict, namedtuple
import copy
from enum import Enum
import gzip
import logging
from operator import attrgetter, itemgetter
import os
from pathlib import Path
import re
import sys
from time import sleep
import unittest

# Python 3.14
from compression import zstd

from tqdm import tqdm

# Configure the root logger.  LOG_LEVEL may name a level (e.g. 'DEBUG');
# when it is unset the level defaults to INFO.
logging.basicConfig(
    level=os.environ.get('LOG_LEVEL', logging.INFO),
    format='%(levelname)s %(message)s')
# NOTE(review): the code in this file logs through the root `logging.*`
# functions rather than through this module logger.
logger = logging.getLogger(__name__)

# Don't print docstrings in unittest's verbose mode: returning None from
# shortDescription() makes the runner show only the test name.
unittest.TestCase.shortDescription = lambda x: None

# One entry of a submitted ranking.  For related-ranking lines `type` holds the
# run file's second column (compared with S/C judgments by Agreement and
# PoolNDCG); for substitute/complementary lines it holds the ranking letter.
# `rank` and `score` are the values read from the run file.
Item = namedtuple('Item', ['docid', 'type', 'rank', 'score'])

# Test fixtures are resolved relative to this script, not the working directory.
SCRIPT_HOME = Path(__file__).parent
TEST_SUBMISSION = SCRIPT_HOME / 'recsys_test_run'
TEST_QRELS = SCRIPT_HOME / 'recsys_test_qrels'

def debug(*args):
    '''Log the space-joined string forms of *args* at DEBUG level (root logger).

    The arguments are converted to strings only when DEBUG logging is
    enabled.  Callers pass whole rankings and gain vectors, and formatting
    those for every topic is wasted work at the default INFO level.
    '''
    if logging.getLogger().isEnabledFor(logging.DEBUG):
        logging.debug(' '.join(map(str, args)))

class Submission:
    '''A TREC 2025 Product track recommendation task submission.

    The run file may be gzip-compressed, which is detected by a ``.gz``
    suffix.  Each line has six whitespace-separated columns::

        topic type docid rank score runtag

    The topic ID ends in one letter naming the ranking the line belongs to:
    ``R`` (related), ``S`` (substitute) or ``C`` (complementary).  Related
    items keep the second column as their ``type``, because that column
    carries the predicted item type.  Substitute and complementary items get
    the ranking letter as their ``type``.

    Attributes:
        related, substitute, complementary: map each topic ID (without the
            R/S/C suffix) to its list of ``Item``s, sorted by decreasing score.
        topics: the set of topic IDs (without suffix) seen in the run.
        runtag: the run tag from the first line.

    Malformed input terminates the program via ``sys.exit`` with a message
    giving the 1-based line number.
    '''
    def __init__(self, filename):
        self.related = defaultdict(list)
        self.complementary = defaultdict(list)
        self.substitute = defaultdict(list)
        self.topics = set()
        self.runtag = None

        # Path() accepts str, Path and any other os.PathLike.
        filename = Path(filename)

        logging.debug('Loading submission')
        TOPIC_RE = re.compile(r'^(PSRT_Recs_\d+)([RCS])$')
        # Both branches open in text mode.  gzip.open defaults to binary, and
        # bytes lines could never match the str regex above.
        if filename.suffix == '.gz':
            fp = gzip.open(filename, 'rt')
        else:
            fp = open(filename, 'r')

        seen = defaultdict(set)
        # `with` closes the file even when sys.exit() raises SystemExit.
        with fp:
            for lineno, line in enumerate(fp, start=1):
                fields = line.split()
                if not fields:
                    continue  # tolerate blank lines, e.g. a trailing newline
                if len(fields) != 6:
                    sys.exit(f'Run has {len(fields)} columns instead of 6 at line {lineno}')
                topic, q0, docid, rank, sim, runtag = fields
                if not self.runtag:
                    self.runtag = runtag

                try:
                    rank = int(rank)
                    sim = float(sim)
                except ValueError:
                    sys.exit(f'Run has bad rank or score at line {lineno}')
                result = TOPIC_RE.match(topic)
                if not result:
                    sys.exit(f'Run has bogus topic {topic} at line {lineno}')
                if docid in seen[topic]:
                    sys.exit(f'Run has duplicate document {docid} for topic {topic} at line {lineno}')
                seen[topic].add(docid)

                tid, ranktype = result.group(1, 2)
                self.topics.add(tid)
                if ranktype == 'R':
                    # The second column is the predicted type of a related item.
                    self.related[tid].append(Item(docid, q0, rank, sim))
                elif ranktype == 'S':
                    self.substitute[tid].append(Item(docid, ranktype, rank, sim))
                elif ranktype == 'C':
                    self.complementary[tid].append(Item(docid, ranktype, rank, sim))
                else:
                    sys.exit(f'Run has topic with bogus ranking indicator {topic} at line {lineno}')

        # Give every topic all three rankings (empty if absent), sorted by
        # decreasing score.  The sort is stable, so ties keep file order.
        by_score = attrgetter('score')
        for tid in self.topics:
            for ranking in (self.related, self.substitute, self.complementary):
                ranking[tid] = sorted(ranking[tid], key=by_score, reverse=True)

class SubmissionTesting(unittest.TestCase):
    def setUp(self):
        self.test_submission_file = TEST_SUBMISSION
        self.assertTrue(self.test_submission_file.exists())

    def _assert_sorted_by_score(self, items, msg=None):
        '''Check that *items* appear in non-increasing score order.'''
        for prev, cur in zip(items, items[1:]):
            self.assertLessEqual(cur.score, prev.score, msg=msg)

    def test_read_submission(self):
        subm = Submission(self.test_submission_file)
        self.assertIsInstance(subm, Submission)
        self.assertEqual(len(subm.related), len(subm.topics))
        self.assertEqual(len(subm.substitute), len(subm.topics))
        self.assertEqual(len(subm.complementary), len(subm.topics))
        # Per-ranking length limits.  Each ranking must itself be sorted, so
        # check the substitute and complementary lists too, not only related.
        limits = ((subm.related, 100), (subm.substitute, 10), (subm.complementary, 10))
        for topic in subm.topics:
            for ranking, limit in limits:
                self.assertLessEqual(len(ranking[topic]), limit, msg=topic)
                self._assert_sorted_by_score(ranking[topic], msg=topic)


class Judgment():
    '''Convenience class to let us compare judgment levels.

    A label is parsed into a ``type`` and an integer ``value``:

    * an integer label (``'0'``, ``'2'``, ``'10'``, ``1``) gives type None
      and that value;
    * ``'C'`` or ``'S'`` followed by an integer (``'S2'``, ``'C0'``) gives
      that type letter and the integer;
    * any other label (e.g. ``'UA'``, ``'NR'``) gives type None and value -1,
      so it sorts below every graded level.

    The label text is kept in ``level``.  Equality and hashing use only
    (type, value).  NOTE: all labels of the last kind therefore compare
    equal to one another (``Judgment('UA') == Judgment('NR')``).
    '''
    def __init__(self, level):
        self.level = str(level)
        self.type = None
        if self.level[:1] in ('C', 'S') and len(self.level) > 1:
            self.type = self.level[0]
            self.value = int(self.level[1:])
        else:
            try:
                self.value = int(self.level)
            except ValueError:
                # Ungraded labels such as 'UA' (unable to assess) or 'NR'.
                self.value = -1

    def __lt__(self, other):
        '''Order by value.  Two typed judgments are only ordered if their types match.'''
        if self.__class__ is other.__class__:
            return (self.type is None or other.type is None or self.type == other.type) and self.value < other.value
        elif other.__class__ == int:
            return self.value < other
        return NotImplemented

    def __hash__(self):
        # Hash exactly the fields __eq__ compares, so equal judgments hash equal.
        return hash((self.type, self.value))

    def __eq__(self, other):
        if self.__class__ is not other.__class__:
            # Falls back to identity, so comparing with e.g. a str is False.
            return NotImplemented
        return self.type == other.type and self.value == other.value

    def __str__(self):
        # The label as written, e.g. 'S2', 'UA' or '1'.
        return self.level

    def __repr__(self):
        return f'Judgment({self.level!r})'

    def same_type(self, other):
        '''True when both are Judgments carrying the same C/S type letter.'''
        return (self.__class__ is other.__class__ and self.type is not None
                and self.type == other.type)

    def is_type(self, typechar):
        '''True when this judgment's type equals *typechar* (None for untyped).'''
        return self.type == typechar

    @staticmethod
    def convert(things):
        '''Turn one label, or a list/tuple/set of labels, into a set of Judgments.'''
        if not isinstance(things, (list, tuple, set, frozenset)):
            things = [things]
        return {Judgment(label) for label in things}

class JudgmentTests(unittest.TestCase):
    def test_judgments(self):
        zero, one = Judgment('0'), Judgment('1')
        self.assertLess(zero, one)

        s0, s1, s2 = (Judgment(label) for label in ('S0', 'S1', 'S2'))
        self.assertLess(s0, s1)
        self.assertLess(s1, s2)

        # An unassessable label sorts below every graded level.
        self.assertLess(Judgment('UA'), s0)

        # Levels of different types are not ordered against each other.
        c1 = Judgment('C1')
        self.assertFalse(c1 < s2)

        self.assertTrue(s0.same_type(s1))
        self.assertFalse(c1.same_type(s1))
        self.assertTrue(s2.is_type('S'))
        self.assertFalse(c1.is_type('S'))

    def test_convert(self):
        self.assertEqual({Judgment('1')}, Judgment.convert('1'))
        expected = {Judgment(label) for label in ('S2', 'S1', 'S0')}
        self.assertEqual(expected, Judgment.convert(['S0', 'S1', 'S2']))


class Qrels:
    '''Relevance judgments read from a TREC-style qrels file.

    Each non-blank line has four whitespace-separated columns,
    ``topic iteration docid label``.  The second column is ignored and the
    label is parsed into a Judgment.  A file whose name ends in ``.gz`` is
    read through gzip, the same way Submission reads run files.  When a
    document is judged twice for a topic, the last judgment wins and a
    warning is logged.

    Raises:
        ValueError: a line does not have exactly four columns.
    '''
    def __init__(self, filename):
        self.rels = defaultdict(dict)   # topic -> {docid: Judgment}
        self.levels = set()             # every distinct Judgment seen
        if Path(filename).suffix == '.gz':
            qrels_file = gzip.open(filename, 'rt')
        else:
            qrels_file = open(filename, 'r')
        with qrels_file:
            for lineno, line in enumerate(qrels_file, start=1):
                fields = line.split()
                if not fields:
                    continue  # tolerate blank lines, e.g. a trailing newline
                if len(fields) != 4:
                    raise ValueError(f'Qrels {lineno}: expected 4 columns, found {len(fields)}')
                topic, _, docid, rel = fields
                if docid in self.rels[topic]:
                    logging.warning(f'Qrels {lineno}: document {docid} appears more than once for topic {topic}, keeping last judgment')
                rel = Judgment(rel)
                self.rels[topic][docid] = rel
                self.levels.add(rel)

    def get(self, topic, docid):
        '''Return the Judgment of *docid* for *topic*, or None if unjudged.'''
        if topic in self.rels and docid in self.rels[topic]:
            return self.rels[topic][docid]
        return None

    def get_topic(self, topic):
        '''Return the {docid: Judgment} dict for *topic*, or None if absent.'''
        if topic in self.rels:
            return self.rels[topic]
        return None

    def num_topics(self):
        '''Number of topics with at least one judgment.'''
        return len(self.rels)

    def topics(self):
        '''A view of the judged topic IDs.'''
        return self.rels.keys()

    def rel_levels(self):
        '''The set of distinct Judgments that occur in the file.'''
        return self.levels

class QrelsTests(unittest.TestCase):
    # Labels and topic count expected in the bundled test qrels.
    EXPECTED_LABELS = ('S0', 'S1', 'S2', 'C0', 'C1', 'C2', 'UA', 'NR')
    EXPECTED_TOPIC_COUNT = 47

    def setUp(self):
        self.assertTrue(TEST_QRELS.exists())
        self.qrels = Qrels(TEST_QRELS)

    def test_check_levels(self):
        expected = {Judgment(label) for label in self.EXPECTED_LABELS}
        self.assertEqual(self.qrels.rel_levels(), expected)

    def test_check_topics(self):
        self.assertEqual(self.qrels.num_topics(), self.EXPECTED_TOPIC_COUNT)


class Measure:
    '''Base class for evaluation measures computed using the EvalJig.

    Attributes:
        formatstr: format string used to render a score.
        label: the measure's name; str(measure) returns it, and the EvalJig
            uses it as the key of its score and mean tables.
        print_score, print_mean: whether the EvalJig prints per-topic scores
            and the summary value for this measure.
        jig: the EvalJig the measure was added to (set by EvalJig.add_op).
        maxdepth: optional rank depth limit read by subclasses, or None.

    compute() computes the measure for one topic.  redux() summarizes
    per-topic scores (for example, an average or max).  pretty() and
    pretty_mean() format the EvalJig's output lines.
    '''
    def __init__(self, maxdepth=None):
        self.formatstr = '{:.4f}'
        self.label = 'measure name'
        self.print_score = True
        self.print_mean = True
        self.jig = None
        # Honour the caller's depth limit (NumRetrieved, PrecAtCutoff and
        # RecallAtCutoff truncate rankings with it).
        self.maxdepth = maxdepth

    def __str__(self):
        return self.label

    def compute(self, topic, resp, qrel):
        '''Score one topic.  The base class always returns 0.'''
        return 0

    def redux(self, scores):
        '''Return the mean of the per-topic *scores* dict, or None when it is empty.

        EvalJig.print_means skips None, so a measure with no scored topics
        prints nothing instead of dividing by zero.
        '''
        if not scores:
            return None
        return sum(scores.values()) / len(scores)

    def _format_row(self, topic, value, prefix, runtag):
        '''Join the optional runtag/prefix columns with topic, label and value.'''
        row = [part for part in (runtag, prefix) if part]
        row.extend([topic, self.label, self.formatstr.format(value)])
        return " ".join(row)

    def pretty(self, topic, value, prefix=None, runtag=None):
        '''Format the measure's score for one topic.'''
        return self._format_row(topic, value, prefix, runtag)

    def pretty_mean(self, value, prefix=None, runtag=None):
        '''Format the measure's summary value, reported under topic "all".'''
        return self._format_row("all", value, prefix, runtag)

class EvalJig:
    '''Run a set of measures over topics, then collect and print the results.

    ``score[measure_name][topic]`` holds per-topic scores and
    ``means[measure_name]`` the summary filled in by comp_means().  An
    optional ``label`` is printed as a prefix column.
    '''
    def __init__(self, label=None):
        self.label = label
        self.topics = set()
        self.means = {}
        self.score = defaultdict(dict)
        self.ops = []

    def add_op(self, meas):
        '''Register measure *meas* and point it back at this jig.'''
        self.ops.append(meas)
        meas.jig = self

    def compute(self, topic, ranking, qrel):
        '''Score *ranking* for *topic* with every registered measure.'''
        self.topics.add(topic)
        for measure in self.ops:
            self.score[str(measure)][topic] = measure.compute(topic, ranking, qrel)

    def zero(self, topic):
        '''Record a score of 0 for *topic* under every measure.'''
        self.topics.add(topic)
        for measure in self.ops:
            self.score[str(measure)][topic] = 0

    def comp_means(self):
        '''Summarize each measure's per-topic scores with its redux().'''
        for measure in self.ops:
            name = str(measure)
            self.means[name] = measure.redux(self.score[name])

    def print_scores_for(self, topic, runtag=None):
        '''Print *topic*'s line for each measure that prints per-topic scores.'''
        for measure in (op for op in self.ops if op.print_score):
            per_topic = self.score[str(measure)]
            if topic in per_topic:
                print(measure.pretty(topic, per_topic[topic], prefix=self.label, runtag=runtag))

    def print_scores(self, runtag=None):
        '''Print per-topic lines for all topics in sorted order.'''
        for topic in sorted(self.topics):
            self.print_scores_for(topic, runtag)

    def print_means(self, runtag=None):
        '''Print the summary line of each measure whose mean is not None.'''
        for measure in (op for op in self.ops if op.print_mean):
            mean = self.means[str(measure)]
            if mean is None:
                continue
            print(measure.pretty_mean(mean, prefix=self.label, runtag=runtag))


class NumQueries(Measure):
    '''Summary-only measure giving the number of scored topics.

    The count covers every topic with an entry in this measure's score
    table.  That includes topics recorded with EvalJig.zero(), which never
    reaches compute().
    '''
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.formatstr = '{:d}'
        self.label = 'num_q'
        self.topics = set()  # topics seen by compute(); kept for existing readers
        self.print_score = False

    def compute(self, topic, ranking, qrel):
        self.topics.add(topic)
        return 0

    def redux(self, scores):
        # Count the scored topics rather than self.topics.  Zeroed topics never
        # pass through compute(), and self.topics would keep growing if one
        # instance were reused across jigs.
        return len(scores)

class NumQueriesTests(unittest.TestCase):
    def test_num_q(self):
        self.assertTrue(TEST_SUBMISSION.exists())
        submission = Submission(TEST_SUBMISSION)

        jig = EvalJig()
        jig.add_op(NumQueries())
        # Score every topic's related ranking; num_q ignores the qrels.
        for topic in submission.topics:
            jig.compute(topic, submission.related[topic], None)
        jig.comp_means()

        self.assertEqual(len(submission.topics), jig.means['num_q'])


class NumRetrieved(Measure):
    '''Number of documents retrieved per topic.

    The ranking is sliced to ``maxdepth`` before counting (no limit when it
    is None).  The summary is the total over all topics.
    '''
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.label = 'num_ret'
        self.formatstr = '{:d}'

    def compute(self, topic, ranking, qrel):
        within_depth = ranking[:self.maxdepth]
        return len(within_depth)

    def redux(self, scores):
        total = 0
        for count in scores.values():
            total += count
        return total

class NumRetrievedTests(unittest.TestCase):
    def test_num_ret(self):
        self.assertTrue(TEST_SUBMISSION.exists())
        submission = Submission(TEST_SUBMISSION)

        jig = EvalJig()
        jig.add_op(NumRetrieved())
        per_topic = jig.score['num_ret']
        for topic in submission.topics:
            jig.compute(topic, submission.related[topic], None)
            # With no depth limit every retrieved item is counted.
            self.assertEqual(len(submission.related[topic]), per_topic[topic])

        expected_total = sum(per_topic[topic] for topic in submission.topics)
        jig.comp_means()
        self.assertEqual(expected_total, jig.means['num_ret'])


class PrecAtCutoff(Measure):
    '''Precision at a rank cutoff: the fraction of the top-ranked documents
    whose judgment is one of *rel_labels*.

    The depth examined is the smallest of the ranking length, ``cutoff`` and
    ``maxdepth`` (the latter two only when set).  That depth is also the
    denominator.  An empty ranking scores 0.
    '''
    def __init__(self, rel_labels, cutoff=None, **kwargs):
        super().__init__(**kwargs)
        self.rel = Judgment.convert(rel_labels)
        self.cutoff = cutoff
        self.label = f'P@{self.cutoff or "end"}'

    def compute(self, topic, ranking, qrel):
        limits = [len(ranking)]
        limits.extend(limit for limit in (self.cutoff, self.maxdepth) if limit)
        cut = min(limits)
        if cut == 0:
            return 0.0
        hits = sum(1 for doc in ranking[:cut]
                   if doc.docid in qrel and qrel[doc.docid] in self.rel)
        return hits / cut
    
class PrecAtCutoffTests(unittest.TestCase):
    def setUp(self):
        # Only document 'c' is relevant (label 1).
        self.qrels = {'a': Judgment(0), 'b': Judgment(0), 'c': Judgment(1)}
        self.prec = PrecAtCutoff(rel_labels=[1])
        self.three_docs = [
            Item('a', None, 0, 100),
            Item('b', None, 1, 99),
            Item('c', None, 2, 98),
        ]

    def test_precision(self):
        value = self.prec.compute('foo', self.three_docs, self.qrels)
        self.assertAlmostEqual(value, float(1/3))

    def test_prec_no_rel_ret(self):
        only_a = [Item('a', None, 0, 100)]
        self.assertEqual(self.prec.compute('foo', only_a, self.qrels), 0)

    def test_prec_empty_ranking(self):
        self.assertEqual(self.prec.compute('foo', [], self.qrels), 0)

    def test_prec_nonexistent_doc(self):
        unknown = [Item('mystery', None, 0, 100)]
        self.assertEqual(self.prec.compute('foo', unknown, self.qrels), 0)

    def test_prec_at_cutoff(self):
        # The only relevant document is third, below cutoff 2.
        prec_at_2 = PrecAtCutoff(rel_labels=[1], cutoff=2)
        self.assertEqual(prec_at_2.compute('foo', self.three_docs, self.qrels), 0)


class RecallAtCutoff(Measure):
    '''Recall at a rank cutoff: the fraction of all documents judged with one
    of *rel_labels* that appear in the top of the ranking.

    The depth examined is the smallest of the ranking length, ``cutoff`` and
    ``maxdepth`` (the latter two only when set).  Topics with no relevant
    documents, and empty rankings, score 0.
    '''
    def __init__(self, rel_labels, cutoff=None, **kwargs):
        super().__init__(**kwargs)
        self.rel = Judgment.convert(rel_labels)
        self.cutoff = cutoff
        # Same convention as PrecAtCutoff: 'R@end', not 'R@None', without a cutoff.
        self.label = f'R@{self.cutoff or "end"}'

    def compute(self, topic, ranking, qrel):
        numrel = sum(1 for rel in qrel.values() if rel in self.rel)
        if numrel == 0:
            return 0.0
        cut = len(ranking)
        if self.cutoff:
            cut = min(cut, self.cutoff)
        if self.maxdepth:
            cut = min(cut, self.maxdepth)
        if cut == 0:
            return 0.0
        found = sum(1 for doc in ranking[:cut]
                    if doc.docid in qrel and qrel[doc.docid] in self.rel)
        return found / numrel

class RecallAtCutoffTests(unittest.TestCase):
    def setUp(self):
        # Only document 'c' is relevant (label 1).
        self.qrels = {'a': Judgment(0), 'b': Judgment(0), 'c': Judgment(1)}
        self.recall = RecallAtCutoff(rel_labels=[1])
        self.three_docs = [
            Item('a', None, 0, 100),
            Item('b', None, 1, 99),
            Item('c', None, 2, 98),
        ]

    def test_recall(self):
        self.assertEqual(self.recall.compute('foo', self.three_docs, self.qrels), 1.0)

    def test_recall_no_rel_ret(self):
        only_a = [Item('a', None, 0, 100)]
        self.assertEqual(self.recall.compute('foo', only_a, self.qrels), 0)

    def test_recall_no_relevant(self):
        no_relevant = {'a': Judgment(0), 'b': Judgment(0)}
        self.assertEqual(self.recall.compute('foo', self.three_docs, no_relevant), 0)

    def test_recall_empty_ranking(self):
        self.assertEqual(self.recall.compute('foo', [], self.qrels), 0)

    def test_recall_nonexistent_doc(self):
        unknown = [Item('mystery', None, 0, 100)]
        self.assertEqual(self.recall.compute('foo', unknown, self.qrels), 0)

    def test_recall_at_cutoff(self):
        # The only relevant document is third, below cutoff 2.
        recall_at_2 = RecallAtCutoff(rel_labels=[1], cutoff=2)
        self.assertEqual(recall_at_2.compute('foo', self.three_docs, self.qrels), 0)


class nDCG(Measure):
    '''Normalized discounted cumulative gain.
    The provided gain_mapping gives a mapping of relevance values to gains.
    At each rank, we accumulate the gain for each retrieved document.
    The gain at rank r is discounted by dividing it by log2(r + 1).
    This cumulative discounted gain is normalized by the ideal DCG, which is
    the DCG obtained by retrieving all documents in decreasing order of gain.
    Reference: Jarvelin and Kekalainen, "Cumulated gain-based evaluation of IR
    techniques", ACM TOIS 20(4):422-446, 2002.

    gain_mapping: maps Judgments (label strings are converted) to gains.
        Judged levels missing from it get gain 0, with a warning.
    cutoff: when set, only the top ``cutoff`` ranks of both the ranking and
        the ideal ranking are scored (nDCG@cutoff), and the label gets a
        ``_cutoff`` suffix.
    maxdepth (from Measure): when set, it also limits how many ranks of the
        submitted ranking are scored.  The ideal ranking is not limited by it.
    '''
    def __init__(self,
                 gain_mapping={Judgment("0"): 0,
                               Judgment("1"): 1,
                               Judgment("2"): 2,
                               Judgment("3"): 5,
                               Judgment("4"): 10},
                 cutoff=None,
                 label_prefix='nDCG',
                 **kwargs):
        super().__init__(**kwargs)
        self.cutoff = cutoff
        # Copy into a fresh dict so the (shared) default mapping is never mutated.
        self.gains = {}
        for gain, value in gain_mapping.items():
            if isinstance(gain, str):
                gain = Judgment(gain)
            self.gains[gain] = value

        if self.cutoff:
            self.label = f'{label_prefix}_{self.cutoff}'
        else:
            self.label = f'{label_prefix}'

    def _gain(self, doc, qrels):
        '''Gain of *doc*: 0 if unjudged, else its judgment's mapped gain (0 if unmapped).'''
        if doc not in qrels:
            return 0.0
        rel = qrels[doc]
        if rel not in self.gains:
            logging.warning(f'Relevance judgment {rel} for document {doc} has an unknown gain value')
        return self.gains.get(rel, 0)

    def _discount(self, rank):
        '''Discount divisor for a 1-based *rank*: log2(rank + 1).'''
        from math import log2
        return log2(rank + 1)

    def _add_gains(self, ranking, qrels):
        '''Return one {'doc': docid, 'gain': gain} dict per item of *ranking*, in order.'''
        return [{'doc': item.docid, 'gain': self._gain(item.docid, qrels)}
                for item in ranking]

    def _rank_discount(self, gain_vec, start_rank=1):
        '''Divide each entry's gain, in place, by the discount for its rank.'''
        for rank, entry in enumerate(gain_vec, start=start_rank):
            entry['gain'] /= self._discount(rank)
        return gain_vec

    def compute(self, topic, ranking, qrels):
        # Only the cutoff/maxdepth prefix of the ranking is scored.
        depths = [depth for depth in (self.cutoff, self.maxdepth) if depth]
        if depths:
            ranking = ranking[:min(depths)]
        debug('nDCG', 'ranking', ranking)
        gain_vec = self._add_gains(ranking, qrels)
        debug('nDCG', 'gain vector', gain_vec)
        gain_vec = self._rank_discount(gain_vec)
        debug('nDCG', 'discounted gain vector', gain_vec)
        dcg = sum(d['gain'] for d in gain_vec)
        debug('nDCG', 'DCG', dcg)

        # Ideal ranking: every positive-gain judged document, typed as judged,
        # in decreasing order of gain.
        ideal_resp = [ Item(doc, qrels[doc].type, 1, 1)
                      for doc in qrels if self._gain(doc, qrels) > 0]
        debug('nDCG', 'Ideal response', ideal_resp)
        ideal_gain_vec = self._add_gains(ideal_resp, qrels)
        debug('nDCG', 'Ideal gain vector', ideal_gain_vec)
        ideal_gain_vec = sorted(ideal_gain_vec,
                                key=itemgetter('gain'),
                                reverse=True)
        if self.cutoff:
            # nDCG@k normalizes by the ideal DCG truncated at the same k.
            ideal_gain_vec = ideal_gain_vec[:self.cutoff]
        debug('nDCG', 'ideal sorted gain vector', ideal_gain_vec)
        ideal_gain_vec = self._rank_discount(ideal_gain_vec)
        debug('nDCG', 'discounted gain vec', ideal_gain_vec)
        idcg = sum(d['gain'] for d in ideal_gain_vec)
        debug('nDCG', 'ideal DCG', idcg)

        if idcg > 0.0:
            debug('nDCG', 'value', f'{dcg/idcg:2f}')
            return dcg / idcg
        else:
            debug('nDCG', 'value', 0.0)
            return 0.0

class nDCGTests(unittest.TestCase):
    def setUp(self):
        self.qrels = {
            'a': Judgment(1),
            'b': Judgment(0),
            'c': Judgment(2),
            'd': Judgment(2),
        }
        self.gains = {Judgment(1): 1, Judgment(2): 10, Judgment(0): 0}
        self.ranking = [
            Item(docid, None, rank, score)
            for rank, (docid, score) in enumerate((('a', 100), ('b', 99), ('c', 98)), start=1)
        ]
        self.ndcg = nDCG(gain_mapping=self.gains)

    def test_gain_value(self):
        gain = self.ndcg._gain
        self.assertEqual(gain('a', self.qrels), 1)
        self.assertEqual(gain('foo', self.qrels), 0)
        self.assertEqual(gain(None, self.qrels), 0)

    def test_discount(self):
        discount = self.ndcg._discount
        self.assertEqual(discount(0), 0)
        self.assertEqual(discount(1), 1)
        self.assertAlmostEqual(discount(100), 6.65821148)
        # log2(0) is a math domain error.
        with self.assertRaises(ValueError):
            discount(-1)

    def test_gain_vec(self):
        for entry in self.ndcg._add_gains(self.ranking, self.qrels):
            self.assertEqual(entry['gain'], self.ndcg._gain(entry['doc'], self.qrels))

    def test_compute_ndcg(self):
        '''
        nDCG ranking [Item(docid='a', type=None, rank=1, score=100), Item(docid='b', type=None, rank=2, score=99), Item(docid='c', type=None, rank=3, score=98)]
        nDCG gain vector [{'doc': 'a', 'gain': 1}, {'doc': 'b', 'gain': 0}, {'doc': 'c', 'gain': 10}]
        nDCG discounted gain vector [{'doc': 'a', 'gain': 1.0}, {'doc': 'b', 'gain': 0.0}, {'doc': 'c', 'gain': 5.0}]
        nDCG DCG 6.0
        nDCG Ideal response [Item(docid='a', type=None, rank=1, score=1), Item(docid='c', type=None, rank=1, score=1), Item(docid='d', type=None, rank=1, score=1)]
        nDCG Ideal gain vector [{'doc': 'a', 'gain': 1}, {'doc': 'c', 'gain': 10}, {'doc': 'd', 'gain': 10}]
        nDCG ideal sorted gain vector [{'doc': 'c', 'gain': 10}, {'doc': 'd', 'gain': 10}, {'doc': 'a', 'gain': 1}]
        nDCG discounted gain vec [{'doc': 'c', 'gain': 10.0}, {'doc': 'd', 'gain': 6.309297535714575}, {'doc': 'a', 'gain': 0.5}]
        nDCG ideal DCG 16.809297535714574
        nDCG value 0.356945
        '''
        self.assertAlmostEqual(self.ndcg.compute(None, self.ranking, self.qrels), 0.3569453)

    def test_compute_ndcg_unknown_doc(self):
        unknown = [
            Item('foo', None, 1, 100),
            Item('bar', None, 2, 99),
            Item('baz', None, 3, 98),
        ]
        self.assertEqual(self.ndcg.compute(None, unknown, self.qrels), 0.0)

    def test_compute_ndcg_empty_ranking(self):
        self.assertEqual(self.ndcg.compute(None, [], self.qrels), 0.0)

class PoolNDCG(nDCG):
    '''Special measure for product recommendation: nDCG of the related ranking
    in which an item whose predicted type (Item.type) is not its judged S/C
    type earns only half of its gain.
    '''
    def __init__(self,
                 gain_mapping={
                     Judgment("S2"): 2,
                     Judgment("S1"): 1,
                     Judgment("S0"): 0,
                     Judgment("C2"): 2,
                     Judgment("C1"): 1,
                     Judgment("C0"): 0,
                     Judgment("NR"): 0,
                     Judgment("UA"): 0},
                 cutoff=None,
                 **kwargs):
        super().__init__(gain_mapping, cutoff, label_prefix='PoolNDCG', **kwargs)

    def _add_gains(self, ranking, qrels):
        gain_vec = super()._add_gains(ranking, qrels)
        debug('PoolNDCG', 'superclass gain_vec', gain_vec)
        for item, entry in zip(ranking, gain_vec):
            # Zero gains stay zero.  Testing them first also avoids looking up
            # documents that are absent from the qrels.
            if entry['gain'] != 0.0 and not qrels[item.docid].is_type(item.type):
                entry['gain'] /= 2.0
        debug('PoolNDCG', 'updated gain_vec', gain_vec)
        return gain_vec

class PoolNDCGTests(unittest.TestCase):
    LABELS = {'a': 'S2', 'b': 'S1', 'c': 'S0', 'd': 'C2',
              'e': 'C1', 'f': 'C0', 'g': 'UA', 'h': 'NR'}

    def setUp(self):
        self.ndcg = PoolNDCG()
        self.qrels = {docid: Judgment(label) for docid, label in self.LABELS.items()}

    def test_poolndcg(self):
        '''
        nDCG ranking [Item(docid='a', type='S', rank=1, score=10), Item(docid='b', type='C', rank=2, score=9), Item(docid='c', type='S', rank=3, score=8), Item(docid='g', type='C', rank=4, score=7), Item(docid='h', type='S', rank=5, score=6), Item(docid='d', type='C', rank=6, score=5), Item(docid='e', type='S', rank=7, score=4), Item(docid='f', type='C', rank=8, score=3)]
        PoolNDCG superclass gain_vec [{'doc': 'a', 'gain': 2}, {'doc': 'b', 'gain': 1}, {'doc': 'c', 'gain': 0}, {'doc': 'g', 'gain': 0}, {'doc': 'h', 'gain': 0}, {'doc': 'd', 'gain': 2}, {'doc': 'e', 'gain': 1}, {'doc': 'f', 'gain': 0}]
        PoolNDCG updated gain_vec [{'doc': 'a', 'gain': 2}, {'doc': 'b', 'gain': 0.5}, {'doc': 'c', 'gain': 0}, {'doc': 'g', 'gain': 0.0}, {'doc': 'h', 'gain': 0.0}, {'doc': 'd', 'gain': 2}, {'doc': 'e', 'gain': 0.5}, {'doc': 'f', 'gain': 0}]
        nDCG gain vector [{'doc': 'a', 'gain': 2}, {'doc': 'b', 'gain': 0.5}, {'doc': 'c', 'gain': 0}, {'doc': 'g', 'gain': 0.0}, {'doc': 'h', 'gain': 0.0}, {'doc': 'd', 'gain': 2}, {'doc': 'e', 'gain': 0.5}, {'doc': 'f', 'gain': 0}]
        nDCG discounted gain vector [{'doc': 'a', 'gain': 2.0}, {'doc': 'b', 'gain': 0.31546487678572877}, {'doc': 'c', 'gain': 0.0}, {'doc': 'g', 'gain': 0.0}, {'doc': 'h', 'gain': 0.0}, {'doc': 'd', 'gain': 0.7124143742160444}, {'doc': 'e', 'gain': 0.16666666666666666}, {'doc': 'f', 'gain': 0.0}]
        nDCG DCG 3.1945459176684396
        nDCG Ideal response [Item(docid='a', type='S', rank=1, score=1), Item(docid='b', type='S', rank=1, score=1), Item(docid='d', type='C', rank=1, score=1), Item(docid='e', type='C', rank=1, score=1)]
        PoolNDCG superclass gain_vec [{'doc': 'a', 'gain': 2}, {'doc': 'b', 'gain': 1}, {'doc': 'd', 'gain': 2}, {'doc': 'e', 'gain': 1}]
        PoolNDCG updated gain_vec [{'doc': 'a', 'gain': 2}, {'doc': 'b', 'gain': 1}, {'doc': 'd', 'gain': 2}, {'doc': 'e', 'gain': 1}]
        nDCG Ideal gain vector [{'doc': 'a', 'gain': 2}, {'doc': 'b', 'gain': 1}, {'doc': 'd', 'gain': 2}, {'doc': 'e', 'gain': 1}]
        nDCG ideal sorted gain vector [{'doc': 'a', 'gain': 2}, {'doc': 'd', 'gain': 2}, {'doc': 'b', 'gain': 1}, {'doc': 'e', 'gain': 1}]
        nDCG discounted gain vec [{'doc': 'a', 'gain': 2.0}, {'doc': 'd', 'gain': 1.261859507142915}, {'doc': 'b', 'gain': 0.5}, {'doc': 'e', 'gain': 0.43067655807339306}]
        nDCG ideal DCG 4.192536065216308
        nDCG value 0.761960
        '''
        ranking = [
            Item('a', 'S', 1, 10),  # correct class, gain 2
            Item('b', 'C', 2, 9),   # incorrect, gain 1/2
            Item('c', 'S', 3, 8),   # correct, but gain 0
            Item('g', 'C', 4, 7),   # UA, unable to assess
            Item('h', 'S', 5, 6),   # NR, class no one gets
            Item('d', 'C', 6, 5),   # correct, gain 2
            Item('e', 'S', 7, 4),   # incorrect, gain 1/2
            Item('f', 'C', 8, 3),   # correct, but gain 0
        ]
        self.assertAlmostEqual(self.ndcg.compute(None, ranking, self.qrels), 0.76196027)

class Agreement(Measure):
    '''Agreement of predicted and actual annotations in the related ranking.

    Each judged item whose predicted type (Item.type) is 'S' or 'C' is
    tallied as agreeing when its judgment label starts with that letter.
    Otherwise, including NR judgments, it is tallied as disagreeing.  The
    tallies are combined into Cohen's kappa using the confusion-matrix
    formulation.  The score is 0.0 when the denominator is 0 (e.g. for an
    empty ranking).
    '''
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.label = 'agreement'

    def compute(self, topic, ranking, qrels):
        # confusion matrix formulation (from https://en.wikipedia.org/wiki/Cohen%27s_kappa)
        # kappa = (2 * (TP * TN - FN * FP)) / ((TP+FP) * (FP+TN) + (TP+FN) * (FN+TN))
        # with TP = ts, FP = fs, TN = tc, FN = fc.
        counts = {('S', True): 0, ('S', False): 0, ('C', True): 0, ('C', False): 0}
        for item in ranking:
            docid = item.docid
            if docid not in qrels:
                debug('agreement', f'{docid} not in qrels')
                continue
            judgment = qrels[docid]
            # NOTE(review): qrels values are Judgment objects, and Judgment.__eq__
            # never matches a str, so this test is never true.  UA-judged items
            # fall through and are counted as disagreements.  AgreementTests
            # (test_some_agreement, test_test_ranking) pin that behaviour, so fix
            # the comparison and those expected values together if UA should be skipped.
            if judgment == 'UA':
                debug('agreement', f'{docid} UA')
                continue
            predicted = item.type
            if predicted not in ('S', 'C'):
                continue
            agree = judgment.level.startswith(predicted)
            if agree:
                debug('agreement', f'{docid} {predicted} agree')
            else:  # judged as the other type, or NR
                debug('agreement', f'{docid} {predicted} not agree')
            counts[(predicted, agree)] += 1

        ts, fs = counts[('S', True)], counts[('S', False)]
        tc, fc = counts[('C', True)], counts[('C', False)]
        debug('agreement', 'ts', ts, 'fs', fs, 'tc', tc, 'fc', fc)
        denom = ((ts + fs) * (fs + tc) + (ts + fc) * (fc + tc))
        if denom == 0.0:
            debug('agreement', '= 0.0')
            return 0.0
        value = (2.0 * (ts * tc - fc * fs)) / denom
        debug('agreement', '=', value)
        return value

class AgreementTests(unittest.TestCase):
    LABELS = {'a': 'S2', 'b': 'S1', 'c': 'S0', 'd': 'C2',
              'e': 'C1', 'f': 'C0', 'g': 'UA', 'h': 'NR'}

    def setUp(self):
        self.agreement = Agreement()
        self.qrels = {docid: Judgment(label) for docid, label in self.LABELS.items()}

    def _score(self, *entries):
        '''Agreement of a ranking given as (docid, type, rank, score) tuples.'''
        ranking = [Item(*entry) for entry in entries]
        return self.agreement.compute(None, ranking, self.qrels)

    def test_no_agreement(self):
        value = self._score(
            ('a', 'S', 1, 10),
            ('b', 'C', 2, 9),
            ('c', 'S', 3, 8),
            ('g', 'C', 4, 7),
            ('h', 'S', 5, 6),
            ('d', 'C', 6, 5),
            ('e', 'S', 7, 4),
            ('f', 'C', 8, 3),
        )
        self.assertEqual(value, 0.0)

    def test_perfect_agreement(self):
        value = self._score(
            ('a', 'S', 1, 10),
            ('c', 'S', 3, 8),
            ('d', 'C', 6, 5),
            ('f', 'C', 8, 3),
        )
        self.assertEqual(value, 1.0)

    def test_some_agreement(self):
        value = self._score(
            ('a', 'S', 1, 10),
            ('b', 'S', 2, 9),
            ('c', 'S', 3, 8),
            ('d', 'C', 6, 5),
            ('g', 'C', 4, 7),
        )
        self.assertAlmostEqual(value, 0.54545454)

    def test_agreement_empty_ranking(self):
        self.assertEqual(self._score(), 0.0)

    def test_test_ranking(self):
        topic = "PSRT_Recs_002"
        submission = Submission(TEST_SUBMISSION)
        qrels = Qrels(TEST_QRELS)
        value = self.agreement.compute(topic, submission.related[topic], qrels.get_topic(topic))
        self.assertAlmostEqual(value, 0.15151515)

class Diversity(Measure):
    '''Mean pairwise product distance over a ranking.

    Every pair of items (earlier, later) in the ranking is looked up in the
    distances table under the key "<earlier docid>-<later docid>".  Pairs
    missing from the table, or all pairs when no table has been loaded,
    count as MAX_DISTANCE.  Rankings with fewer than two items score
    MAX_DISTANCE.
    '''
    # The distances mapping is common to all Diversity metric instances.
    # This means we don't have to load the table for both complementary
    # and substitutes.  It stays None until a load has fully succeeded.
    distances = None

    def __init__(self, distances_file=None, **kwargs):
        '''Create the measure, loading the shared distances table if needed.

        distances_file: zstd-compressed CSV of product distances.  It is read
            only if no table has been loaded yet; otherwise it is ignored.
        '''
        super().__init__(**kwargs)
        self.label = 'diversity'
        self.MAX_DISTANCE = 2.0

        if distances_file and Diversity.distances is None:
            # Publish the table only after it has been read completely.  A
            # failed load then leaves the cache as None (so it can be
            # retried), instead of an empty or partial dict that would
            # silently score every pair as MAX_DISTANCE.
            Diversity.distances = self._load_distances(distances_file)

    @staticmethod
    def _load_distances(distances_file):
        '''Read the distances CSV into a dict keyed by "asin1-asin2".

        The header row (starting with "asin1") and blank lines are skipped.
        Of the four columns, the fourth is stored as the distance.
        '''
        logging.info('Diversity: reading distances')
        table = {}
        with zstd.open(distances_file) as dist_file:
            for raw in tqdm(dist_file, desc='Loading distances'):
                line = raw.decode().strip()
                if not line or line.startswith('asin1'):
                    continue
                asin1, asin2, _dist, fdist = line.split(',')
                table[f'{asin1}-{asin2}'] = float(fdist)
        return table

    def compute(self, topic, ranking, qrels):
        '''Return the mean distance over all item pairs in ranking.

        topic and qrels are not used; they are accepted so the signature
        matches the other measures' compute(topic, ranking, qrels).
        '''
        if len(ranking) < 2:
            return self.MAX_DISTANCE
        # An unloaded (None) or empty table means every pair is MAX_DISTANCE.
        table = Diversity.distances or {}
        dists = []
        for i, first in enumerate(ranking):
            for second in ranking[i + 1:]:
                key = f'{first.docid}-{second.docid}'
                dists.append(table.get(key, self.MAX_DISTANCE))
        # Sum the materialised list with builtin sum(), as before, so the
        # floating-point result is unchanged.
        return sum(dists) / len(dists)


@unittest.skipIf(os.environ.get('SKIP_DIVERSITY', False), 'Skipped due to SKIP_DIVERSITY env var')
class DiversityTests(unittest.TestCase):
    '''Tests for the Diversity measure; they need the product distances table.'''

    # Name of the distances table.  It is looked up in the working directory
    # first, as before, then next to this script, like TEST_SUBMISSION and
    # TEST_QRELS, so the tests do not depend on the current directory.
    DISTANCES_FILE = 'product-distances.csv.zst'

    def setUp(self):
        distances = Path(self.DISTANCES_FILE)
        if not distances.exists():
            distances = SCRIPT_HOME / self.DISTANCES_FILE
        self.diversity = Diversity(distances_file=distances)
        self.qrels = Qrels(TEST_QRELS)

    def test_diversity(self):
        value = self.diversity.compute(None, [
            Item('B07QDBDW5Z', 'S', 1, 10),
            Item('B07GYZ5997', 'S', 2, 9),
            Item('B07TKFKLSX', 'S', 3, 8),
            Item('B07J1Y5PH5', 'S', 4, 7),
            Item('B07J2SMZN3', 'S', 5, 6),
            Item('B07H9NCHSS', 'S', 6, 5),
            Item('B07J2SMWRX', 'S', 7, 4),
            Item('B07FRTQKMW', 'S', 8, 3),
            Item('B07VWXRWBZ', 'S', 9, 2),
            Item('B079Y985C6', 'S', 10, 1)
        ], self.qrels)
        # The value is a mean of floats, so compare approximately.
        self.assertAlmostEqual(value, 0.4)

    def test_nonzero_dists(self):
        value = self.diversity.compute(None, [
            Item("B07DSSDXQW", "S", 1, 10),
            Item("B015SBLH98", "S", 2, 9),
            Item("B00N5ADRMA", "S", 3, 8),
            Item("B07F4JLZQK", "S", 4, 7),
            Item("B07F4DTK9M", "S", 5, 6),
            Item("B073P32KVV", "S", 6, 5),
            Item("B073P3DPTS", "S", 7, 4),
            Item("B0711BKFKY", "S", 8, 3),
            Item("B01JPEO84K", "S", 9, 2),
            Item("B00HSGKSR4", "S", 10, 1),
        ], self.qrels)
        self.assertAlmostEqual(value, 1.1037037037037036)

    def test_empty_ranking(self):
        value = self.diversity.compute(None, [], self.qrels)
        self.assertEqual(value, 2.0)

    def test_one_item_ranking(self):
        value = self.diversity.compute(None, [
            Item("B07DSSDXQW", "S", 1, 10),
        ], self.qrels)
        self.assertEqual(value, 2.0)


if __name__ == '__main__':
    ap = argparse.ArgumentParser(
        description='Score a product track recommendation task run',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    ap.add_argument('qrels',
                    type=Path,
                    help='Relevance judgments file')
    ap.add_argument('distances',
                    type=Path,
                    help='Product distances file')
    ap.add_argument('runfiles',
                    type=Path,
                    nargs='+',
                    help='The run to be measured')

    ap.add_argument('--log', type=str, default='WARNING')

    ap.add_argument('-v', '--verbose', action='store_true')

    # type=int: the value is passed to every measure as a depth limit, so it
    # must not be left as the raw command-line string.
    ap.add_argument('-M', '--maxdepth',
                    type=int,
                    help='Maximum depth for evaluation',
                    default=None)

    args = ap.parse_args()

    # Reject unknown level names with a usage error instead of a traceback.
    log_level = getattr(logging, args.log.upper(), None)
    if not isinstance(log_level, int):
        ap.error(f'unknown logging level: {args.log}')
    logging.basicConfig(level=log_level,
                        format='%(levelname)s {%(funcName)s:%(lineno)d} %(message)s',
                        force=True)
    logger = logging.getLogger(__name__)
    # -v raises verbosity to at least INFO but never hides DEBUG output that
    # --log already enabled.  The measures in this file log through the root
    # logging functions, so the root logger is the one to adjust.
    root_logger = logging.getLogger()
    if args.verbose and root_logger.getEffectiveLevel() > logging.INFO:
        root_logger.setLevel(logging.INFO)
    logger.debug('Logging level set to %s', args.log)

    qrels = Qrels(args.qrels)

    # Design notes:
    # 1. The jig is built around the idea of many scores for one ranking,
    #    so each ranking (complementary, substitute, related) gets its own
    #    jig.  That is fine, because the rankings use different metrics.
    # 2. An item is judged substitute, complementary, or neither.  When
    #    scoring a ranking, the measures must know whether it is the
    #    substitute or the complementary ranking.

    def make_type_jig(label, rel, other, distances_file=None):
        '''Build the jig for one typed ranking.

        rel is the judgment prefix counted as relevant for this ranking
        ('C' or 'S').  other is the opposite prefix, whose grades earn no
        nDCG gain.  distances_file is forwarded to Diversity.
        '''
        relevant = [f'{rel}2', f'{rel}1']
        jig = EvalJig(label=label)
        jig.add_op(NumQueries(maxdepth=args.maxdepth))
        jig.add_op(NumRetrieved(maxdepth=args.maxdepth))
        # Each measure gets its own copy of the relevant-label list.
        for cutoff in (1, 5, 10):
            jig.add_op(PrecAtCutoff(list(relevant), cutoff=cutoff, maxdepth=args.maxdepth))
        for cutoff in (1, 5, 10):
            jig.add_op(RecallAtCutoff(list(relevant), cutoff=cutoff, maxdepth=args.maxdepth))
        gains = {f'{rel}2': 2, f'{rel}1': 1, f'{rel}0': 0,
                 f'{other}2': 0, f'{other}1': 0, f'{other}0': 0,
                 'NR': 0, 'UA': 0}
        jig.add_op(nDCG(gains, cutoff=10, maxdepth=args.maxdepth))
        jig.add_op(Diversity(distances_file=distances_file))
        return jig

    c_jig = make_type_jig('complementary', 'C', 'S', distances_file=args.distances)
    # Diversity shares its distances table across instances, so the file
    # only needs to be given (and loaded) once.
    s_jig = make_type_jig('substitute', 'S', 'C')

    class DummyOp(Measure):
        '''A placeholder measure whose per-topic score is filled in directly.'''
        def __init__(self, label, **kwargs):
            super().__init__(**kwargs)
            self.label = label

    avg_jig = EvalJig(label='related')
    avg_jig.add_op(DummyOp('avg_ndcg'))
    avg_jig.add_op(Agreement(maxdepth=args.maxdepth))
    avg_jig.add_op(PoolNDCG(cutoff=10, maxdepth=args.maxdepth))

    for runfile in args.runfiles:
        subm = Submission(runfile)
        for topic in qrels.topics():
            c_jig.zero(f'{topic}C')
            c_jig.compute(f'{topic}C', subm.complementary[topic], qrels.get_topic(topic))
            s_jig.zero(f'{topic}S')
            s_jig.compute(f'{topic}S', subm.substitute[topic], qrels.get_topic(topic))

            avg_jig.zero(f'{topic}R')
            avg_jig.compute(f'{topic}R', subm.related[topic], qrels.get_topic(topic))
            # avg_ndcg for the related ranking is the mean of the typed nDCG@10s.
            c_ndcg = c_jig.score['nDCG_10'][f'{topic}C']
            s_ndcg = s_jig.score['nDCG_10'][f'{topic}S']
            avg_jig.score['avg_ndcg'][f'{topic}R'] = (c_ndcg + s_ndcg) / 2.0

        c_jig.print_scores(runtag=subm.runtag)
        s_jig.print_scores(runtag=subm.runtag)
        avg_jig.print_scores(runtag=subm.runtag)

        c_jig.comp_means()
        s_jig.comp_means()
        avg_jig.comp_means()

        c_jig.print_means(runtag=subm.runtag)
        s_jig.print_means(runtag=subm.runtag)
        avg_jig.print_means(runtag=subm.runtag)

