/** * Author: Sofus A. Macskassy (smacskas@stern.nyu.edu) * Date: December 5, 2004 *

* This code is released under GPL. You are free to use * it for any purpose as long as credit is clearly given * to the author. * * This program implements the algorithms described in: * "Significance Testing against the Random Model for Scoring Models on Top k Predictions" * Sofus A. Macskassy, submitted to JAIR, 2004. */ import java.io.*; import java.util.ArrayList; import java.util.Arrays; import java.util.Date; import java.util.regex.Pattern; import java.text.NumberFormat; public class EvalTopK { private static class FilteredWriter { boolean quiet = false; PrintWriter pw; public FilteredWriter() { pw = new PrintWriter(System.out, true); } public FilteredWriter(PrintWriter pw) { this.pw = pw; } public void println(String s) { if(!quiet) pw.println(s); } public void println() { if(!quiet) pw.println(); } public void print(String s) { if(!quiet) pw.print(s); } public void print(int i) { if(!quiet) pw.print(i); } public void print(double d) { if(!quiet) pw.print(d); } public void close() { pw.close(); } public void flush() { pw.flush(); } } private static abstract class Printer { PrintWriter pw=null; double prevPos=0; double prevNeg=0; public abstract void print(double pos, double neg); public abstract void close(); public abstract void open(String stem) throws IOException; public Printer(String stem) throws IOException { open(stem); } } private static class ROCPrinter extends Printer { double rocPos; double rocNeg; double AUC; public ROCPrinter(String stem) throws IOException { super(stem); AUC = 0; } public void open(String stem) throws IOException { pw = new PrintWriter(new FileWriter(stem+".roc")); pw.println("# Created: "+new Date()); pw.println("# totpos="+totPos); pw.println("# totsize="+totSize); pw.println("0 0"); } public void print(double pos, double neg) { if (pos != rocPos && neg != rocNeg) { if(rocNeg != prevNeg) AUC += (rocNeg-prevNeg)*((rocPos+prevPos)/2); rocPos = prevPos; rocNeg = prevNeg; pw.println(prevNeg/totNeg+" "+prevPos/totPos); } prevPos = 
pos; prevNeg = neg; } public void close() { print(1,1); pw.println("# AUC: "+AUC); pw.close(); } } private static class PRPrinter extends Printer { double numPrint = 0; public PRPrinter(String stem) throws IOException { super(stem); } public void open(String stem) throws IOException { pw = new PrintWriter(new FileWriter(stem+".pr")); pw.println("# recall precision"); pw.println("0 1"); } public void print(double pos, double neg) { numPrint++; pw.println(prevPos/totPos+" "+prevPos/numPrint); prevPos = pos; prevNeg = neg; } public void close() { print(totPos,totSize); pw.close(); } } private static class ThresholdInfo implements Comparable { Printer roc = null; Printer pr = null; double cutoff=0; double threshold=0; int treeIndex=0; public ThresholdInfo(double threshold) { this.threshold = threshold; this.cutoff = 1-threshold; } public int compareTo(ThresholdInfo ti) { double diff = threshold-ti.threshold; return ( (diff<0) ? 1 : ((diff > 0) ? -1 : 0 ) ); } public void openFiles() throws IOException { if(rocStem != null) roc = new ROCPrinter(rocStem+"-"+threshold); if(prStem != null) pr = new PRPrinter(prStem+"-"+threshold); } public void print() throws IOException { if(roc != null) roc.print((int)graph[0][treeIndex].pos, (int)graph[0][treeIndex].neg); if(pr != null) pr.print((int)graph[0][treeIndex].pos, (int)graph[0][treeIndex].neg); } public void closeFiles() throws IOException { if(roc != null) roc.close(); if(pr != null) pr.close(); } } private static class EvalInfo implements Comparable { static final Pattern splitter = Pattern.compile(","); int k; double[] numP; public EvalInfo(String s) { String[] v = splitter.split(s); if(v.length<2) usage("eval parameter '"+s+"' invalid - it must contain at least one ','"); numP = new double[v.length-1]; k = Integer.parseInt(v[0]); if(k<1 || (ratio<0 && k > numData)) usage("eval parameter("+s+") has k="+k+" be larger than the data size("+numData+")!"); for(int i=1;i k || (ratio<0 && numP[i-1] > numPos)) usage("eval p-value 
of '"+numP[i-1]+"' is greater than k("+k+") or total number of positives("+numPos+")!"); } } public int compareTo(EvalInfo ei) { double diff = k-ei.k; return ( (diff<0) ? -1 : ((diff > 0) ? 1 : 0 ) ); } } private static class GraphNode { static NumberFormat rf = NumberFormat.getInstance(); static NumberFormat cf = NumberFormat.getInstance(); static { rf.setMaximumFractionDigits(4); cf.setMaximumFractionDigits(21); } double pos=0; double neg=0; double rki=0; double cumRKI=0; public String toString() { return "["+(int)pos+","+(int)neg+","+rf.format(rki)+","+cf.format(cumRKI)+"]"; } } private static class PredictNode implements Comparable { double score; // what is the score at this point boolean positive; // is this node positive int numP; // how many nodes with a score >= this score were positive. double threshold; // what is the threshold for this prediction public PredictNode(double score, boolean label) { this.score = score; this.positive = label; numP = 0; } public int compareTo(PredictNode pn) { double diff = score-pn.score; return ( (diff<0) ? 1 : ((diff > 0) ? -1 : 0 ) ); } } static String gfile = null; static PrintWriter graphOut = null; static GraphNode[][] graph; static ThresholdInfo[] thresholds = new ThresholdInfo[0]; static EvalInfo[] evals = new EvalInfo[0]; static String pfile = null; static PredictNode[] predictions = null; static String rocStem = null; static String prStem = null; static String distribStem = null; static FilteredWriter console = new FilteredWriter(); static String ofile = null; static FilteredWriter out = new FilteredWriter(); static boolean interpolate = false; static double ratio=-1; static int currK=0; static int minK=1; static int maxK=-1; static double totPos; static double totNeg; static double totSize; static int treeSize; static int numPos = -1; static int numData = -1; static int lag = -1; /** * print out the usage. Print extra help if no usage message is given. * @param msg message to print instead of verbose help. 
*/
    private static void usage(String msg) {
        // Always show the usage text, even if the console was silenced.
        // NOTE(review): the <...> placeholders below had been stripped from
        // the source text. Those for -ds, -pd, -pp and -pr follow the visible
        // argument parsing and file-name code; the remaining names are
        // best-effort reconstructions and should be confirmed.
        console.quiet = false;
        console.println("Usage: eval_topK -h");
        console.println("Usage: eval_topK [OPTIONS] -dp <prediction-file>");
        console.println("Usage: eval_topK [OPTIONS] (DATA-PARAM) p (pvalue)*");
        console.println("Usage: eval_topK [OPTIONS] (DATA-PARAM) e <K>,<P>[,<P>,<P>...] [<K>,<P>...]");
        console.println("");
        console.println("Description:");
        console.println("  This program is based on the algorithm described in:");
        console.println("    'Evaluating Scoring Models on Top-K Predictions',");
        console.println("    Sofus A. Macskassy, submitted to JAIR (2004).");
        console.println("");
        if (msg == null) {
            console.println("OPTIONS");
            console.println("  -c <lag>      Will calculate, for each threshold given, all the starting K");
            console.println("                where the prediction file starts outperforming the given");
            console.println("                threshold for more than <lag> consecutive K.");
            console.println("                You must use '-dp' and specify at least one threshold");
            console.println("                when using this option.");
            console.println("  -g <file>     Will send graph to <file>.");
            console.println("                If no -t is given, the graph will not be printed.");
            console.println("                WARNING: use with caution. This will be a large file.");
            console.println("  -i            Interpolate the values to estimate p and x_k.");
            console.println("  -mx <K>       Will calculate significance values up to (and including) K");
            console.println("  -mn <K>       Will calculate significance values starting from K");
            console.println("  -o <file>     Will send output to <file>.");
            console.println("                If no -o is given, all output goes to stdout.");
            console.println("  -pd <stem>    Will write the probability density function and cumulative");
            console.println("                density function for each k with names:");
            console.println("                   <stem>.<k>.pdf and <stem>.<k>.cdf");
            console.println("  -pp <stem>    Will write a P/R curve for each threshold given with name:");
            console.println("                   <stem>-<threshold>.pr");
            console.println("                If a prediction file is given, then <stem>.predict.pr is");
            console.println("                also written.");
            console.println("  -pr <stem>    Will write an ROC curve for each threshold given with name:");
            console.println("                   <stem>-<threshold>.roc");
            console.println("                If a prediction file is given, then <stem>.predict.roc is");
            console.println("                also written.");
            console.println("");
            console.println("DATA PARAMS (must specify one, and only one, of these):");
            console.println("  -dc <P+>      Will set the data characteristics to be of infinite size,");
            console.println("                with a given P+ probability of being '+'. You must specify");
            console.println("                a maximum k (-mx) when using this setting.");
            console.println("  -dp <file>    Will read the given list of scores, with a 0 or 1 class label.");
            console.println("                Input line format:");
            console.println("                   <score> <0|1|+|->");
            console.println("  -ds <P> <N>   Will set the data characteristics to be of the given size (N)");
            console.println("                with the given number of positives (P)");
            console.println("                Output line format if -d is given:");
            console.println("");
            console.println("COMMANDS:");
            console.println("  e <K>,<P>[,<P>...] [<K>,<P>,<P>...]");
            console.println("                For each K, will evaluate the significance of seeing each of");
            console.println("                the P positives in the top K predictions. You must specify");
            console.println("                at least one P per K. This will output one line per K, with");
            console.println("                the list of threshold values.");
            console.println("                Output line-format is:");
            console.println("                   <K>,<pvalue>,<pvalue>,...");
            console.println("  p pvalue [pvalue ....]");
            console.println("                Will compute for each p-value and each K (min K -> max K)");
            console.println("                the number of positives needed to reach that p-value.");
            console.println("                If -dp was used to set the data parameters, then it will also");
            console.println("                compute the number of positives seen in the prediction file");
            console.println("                for each K and their respective p-values on a per-K basis.");
            console.println("                Output line-format is:");
            console.println("                   <K> <numP> <pvalue> <P-1> ... <P-n>");
            console.println("                where <numP> and <pvalue> are for the given prediction");
            console.println("                file at top-K and P-# is the number of positives");
            console.println("                needed to be at that pvalue at that K");
            console.println("");
        } else {
            console.println(msg);
        }
        // NOTE(review): exits with status 0 even when msg reports an error;
        // kept as-is so existing callers/scripts see the same exit status.
        System.exit(0);
    }

    /**
     * compute the next row in the top-K tree.
* @param size the number of elements in the current top-K row * @return the number of elements in the new top-K row */ private static int next(int size) { GraphNode[] swap; double levelSize = totSize - (graph[0][0].pos + graph[0][0].neg); graph[1][0].rki = 0; int j = 0; for (int i = 0; i < size && j < treeSize; i++) { if (ratio>0) { graph[1][j].pos = graph[0][i].pos; graph[1][j].neg = graph[0][i].neg + 1; graph[1][j].rki += graph[0][i].rki * (1 - ratio); j++; graph[1][j].rki = 0; graph[1][j].pos = graph[0][i].pos + 1; graph[1][j].neg = graph[0][i].neg; graph[1][j].rki += graph[0][i].rki * ratio; } else { if (graph[0][i].neg < totNeg) { graph[1][j].pos = graph[0][i].pos; graph[1][j].neg = graph[0][i].neg + 1; graph[1][j].rki += graph[0][i].rki * ((totNeg - graph[0][i].neg) / levelSize); j++; graph[1][j].rki = 0; } if (j < treeSize && graph[0][i].pos < totPos) { graph[1][j].pos = graph[0][i].pos + 1; graph[1][j].neg = graph[0][i].neg; graph[1][j].rki += graph[0][i].rki * ((totPos - graph[0][i].pos) / levelSize); } else { // we went beyond the last cell in tree[1][.] when adding a negative // i.e., there are no more positives, and we need to pull j back for // boundary safety. j--; } } } graph[1][0].cumRKI =graph[1][0].rki; for(int i=1;i<=j;i++) graph[1][i].cumRKI = graph[1][i-1].cumRKI + graph[1][i].rki; if(Math.abs(1-graph[1][j].cumRKI) > 0.00000001) System.err.println("ERROR[level="+currK+"] - cumRKI is too far from 1.0 ("+graph[1][j].cumRKI+"!"); if(graphOut != null && currK >= minK) { graphOut.print("level-"+currK+" ="); for(int i=0;i<=j;i++) { graphOut.print(" "); graphOut.print(graph[1][i]); } graphOut.println(); } swap = graph[0]; graph[0] = graph[1]; graph[1] = swap; return (j + 1); } /** * Read a prediction file, consisting of scores and a +/- label. Use this to set the * data characteristics (e.g., N=dataset-set, P=num-positives, p=P/N). * @param f the file to read from * @return a sorted array of PredictNode objects, sorted on their score. 
* @throws IOException If there was an error in reading the given file. */ private static PredictNode[] readPredictionFile(File f) throws IOException { ArrayList p = new ArrayList(); Pattern splitter = Pattern.compile(" "); LineNumberReader lr = new LineNumberReader(new FileReader(f)); for(String s = lr.readLine(); s != null; s = lr.readLine()) { String[] d = splitter.split(s); if(d.length != 2 || d[1].length() > 1) throw new IOException("invalid input("+s+") at line "+lr.getLineNumber()+". Each line must be of format ' (0|1|+|-)>'"); boolean label = false; switch(d[1].charAt(0)) { case '0': case '-': label = false; break; case '1': case '+': label = true; break; default: throw new IOException("invalid label("+d[1]+") at line "+lr.getLineNumber()+". Each line must be of format ' (0|1)>'"); } double score = Double.parseDouble(d[0]); p.add(new PredictNode(score,label)); } lr.close(); PredictNode[] pa = p.toArray(new PredictNode[0]); Arrays.sort(pa); int np = 0; for(PredictNode pn : pa) { if(pn.positive) np++; pn.numP = np; } return pa; } /** * set up various varaiables after having read the arguments. this includes setting * up (properly) the size of the data, output files, etc. 
* @throws IOException */ private static void setup() throws IOException { totPos = (double)numPos; totSize = (double)numData; totNeg = totSize - totPos; if(maxK<0) maxK = numData; console.println("numPos="+numPos); console.println("size ="+numData); if(ofile!=null) { console.println("output file="+ofile); out = new FilteredWriter(new PrintWriter(new FileWriter(ofile))); } if(gfile!=null) { console.println("print tree to file="+gfile); console.println(" WARNING: this will be a large file!"); graphOut = new PrintWriter(new FileWriter(gfile)); } if(thresholds.length > 0) { console.print("thresholds ="); for(ThresholdInfo ti : thresholds) { ti.openFiles(); console.print(" "+ti.threshold+"["+ti.cutoff+"]"); } console.println(); } if(evals.length > 0) { console.print("eval ="); for(EvalInfo ei : evals) { console.print(" "+ei.k); for(double np : ei.numP) console.print(","+np); } console.println(); minK = evals[0].k; maxK = evals[evals.length-1].k; } console.println("minK ="+minK); console.println("maxK ="+maxK); if(lag!=-1) console.println("using lag(-c)="+lag); if(prStem!=null) console.println("output P/R curves with stem: "+prStem); if(rocStem!=null)console.println("output ROC curves with stem: "+rocStem); // If known data size, then maximum tree level size is (numData/2)+1 treeSize = Math.min(((ratio>0) ? (maxK+1) : (1 + ((1 + numData) / 2))),maxK+1); graph = new GraphNode[2][treeSize]; for(int i=0;i2 ? Character.toLowerCase(args[i].charAt(2)) : '\0'); if(opt != 'q') i++; if(i == args.length) usage("Missing parameter(s) for option '"+args[i]+"'"); switch(opt) { case 'c': lag = Integer.parseInt(args[i]); if(lag < 1) usage("Invalid value of "+lag+" for '-c'. It must be positive"); break; case 'd': if(data != 0) usage("Parameter("+args[i-1]+") - You already specified -d"+data+". You cannot specify more than one -d parameter."); data = opt2; switch(opt2) { case 'c': ratio = Double.parseDouble(args[i]); if(ratio<=0 || ratio>=1) usage("invalid -dc value of '"+ratio+"'. 
It must lie in the range (0,1) [exclusive]."); numData = Integer.MAX_VALUE; numPos = -1; break; case 'p': pfile = args[i]; break; case 's': if(i+1 == args.length) usage("Missing parameter(s) for option '"+args[i]+"'"); numPos = Integer.parseInt(args[i]); numData = Integer.parseInt(args[i+1]); if(numPos < 1 || numData <= numPos) usage("Invalid values for '-dp'. numPos="+numPos+" size="+numData+". Both must be positive and numPos0 && (data != 'p' || i==args.length || Character.toLowerCase(args[i].charAt(0)) != 't')) usage("You must use -dp and use command 'p' when using -c"); if(ratio > 0 && (rocStem != null || prStem != null)) usage("You cannot use -pr or -pp with -dc as we don't have a finite size nor a known number of positives."); if(pfile != null) { console.println("reading predictions from "+pfile); predictions = readPredictionFile(new File(pfile)); numData = predictions.length; numPos = predictions[predictions.length-1].numP; } if(i= 1) usage("Invalid of "+p+". It must be between 0 and 1!\n"); thresholds[j] = new ThresholdInfo(1-p); i++; j++; } Arrays.sort(thresholds); break; default: usage("unexpected token: "+args[i]); break; } } }

    /**
     * Look up the cumulative probability stored in the current top-K row
     * (graph[0]) for a given number of positives.
     * The row is indexed relative to the positive count of its first node.
     * When interpolation is switched on and numPos is not exactly the
     * positive count of the selected node, the result is moved towards the
     * next node's cumulative value in proportion to the difference.
     * @param numPos how many positives were seen (currK, totSize, totPos are assumed to have been set)
     * @return the CDF
     */
    private static double getCDF(double numPos) {
        final GraphNode[] row = graph[0];
        final int firstPos = (int) row[0].pos;
        final int idx = (int) numPos - firstPos;
        final GraphNode node = row[idx];

        double cdf = node.cumRKI;
        if (interpolate && node.pos != numPos) {
            final double fraction = numPos - node.pos;
            cdf = cdf + fraction * (row[idx + 1].cumRKI - cdf);
        }
        return cdf;
    }

    /** * main program. it gets the args, sets up variables and does what needs to be done. * @param args * @throws IOException */ public static void main(String[] args) throws IOException { parseArgs(args); setup(); // set up how often to print a status line. At least once every 1000 lines, // or 250 times throughout the run. 
int mod = Math.min(numData/250,1000); int size = 1; // if thresholds are to be printed, print this header first if(thresholds.length > 0) { out.print("# K"); if(predictions != null) out.print(" predFileNumP predFileThreshold"); for(ThresholdInfo ti : thresholds) out.print(" numP("+ti.threshold+")"); out.println(); } currK=1; int evalK = 0; // advance to the first K to evaluate on. Since the algorithm works // be growing one k at a time, we still need to do this. for (;currK < minK; currK++) size = next(size); // loop until we've reached the max K specified for (; currK <= maxK; currK++) { // advance to the next level size = next(size); if(distribStem != null) { PrintWriter pwc = new PrintWriter(new FileWriter(distribStem+"."+currK+".cdf")); PrintWriter pwd = new PrintWriter(new FileWriter(distribStem+"."+currK+".pdf")); pwc.println("# "); pwd.println("# "); for(int i=0;i 0) { // print: // [ ] thresholds[0].numP... thresholds[T].numP out.print(currK); // print stats in the prediction file for this top K if(predictions != null) { out.print(" "); out.print(predictions[currK-1].numP); out.print(" "); double outP = getCDF(predictions[currK-1].numP); out.print(outP); predictions[currK-1].threshold = outP; } int j=size-1; // print stats for each threshold given for(ThresholdInfo ti : thresholds) { // let's find the treeIndex for the threshold while(j>0 && graph[0][j].cumRKI > ti.threshold) j--; ti.treeIndex = j; out.print(" "); double np = graph[0][ti.treeIndex].pos; if(interpolate) np += (ti.threshold-graph[0][ti.treeIndex].cumRKI)/(graph[0][ti.treeIndex+1].cumRKI-graph[0][ti.treeIndex].cumRKI); else if(graph[0][ti.treeIndex].cumRKI != ti.threshold) np++; out.print(np); ti.print(); } out.println(); } // print a ststus line if(numData < 250 || (currK%mod)==0) { console.print(currK+" "+size+" \r"); console.flush(); } } // close all output files (ROC + P/R curves) for(ThresholdInfo ti : thresholds) ti.closeFiles(); console.println(maxK+" "+size+" "); // now let's find out, 
if user-specified, were the given prediction file // crosses each given threshold. It must be as good as the threshold at // least 'lag' consecutive k's before being counted. if(lag>0) { console.println("Finding runs in prediction file where it outperforms each threshold"); for(ThresholdInfo ti : thresholds) { out.print("Exceed-Threshold("+ti.threshold+"):"); int start=-1; int run=0; for(int i=0;i= ti.threshold) { run++; if(start < 0) start=i; } else { if(run>=lag) { out.print(" [k="+(start+1)+",+="+predictions[start].numP+"]"); out.print("-[k="+i+",+="+predictions[i-1].numP+"]"); } start = -1; run = 0; } } if(run>=lag) { out.print(" [k="+(start+1)+",+="+predictions[start].numP+"]"); out.print("-[k="+maxK+",+="+predictions[maxK-1].numP+"]"); } out.println(); } } console.println(maxK+" "+size+" "); console.println("DONE!"); if(graphOut != null) graphOut.close(); out.close(); console.close(); } }