diff --git a/psl-cli/src/main/java/org/linqs/psl/cli/Launcher.java b/psl-cli/src/main/java/org/linqs/psl/cli/Launcher.java index 0d4e9a315..62bc39f33 100644 --- a/psl-cli/src/main/java/org/linqs/psl/cli/Launcher.java +++ b/psl-cli/src/main/java/org/linqs/psl/cli/Launcher.java @@ -18,8 +18,11 @@ package org.linqs.psl.cli; import org.linqs.psl.application.inference.InferenceApplication; +import org.linqs.psl.application.learning.weight.TrainingMap; import org.linqs.psl.application.learning.weight.WeightLearningApplication; import org.linqs.psl.application.learning.weight.maxlikelihood.MaxLikelihoodMPE; +import org.linqs.psl.config.Options; +import org.linqs.psl.database.atom.PersistedAtomManager; import org.linqs.psl.database.DataStore; import org.linqs.psl.database.Database; import org.linqs.psl.database.Partition; @@ -32,6 +35,8 @@ import org.linqs.psl.grounding.GroundRuleStore; import org.linqs.psl.model.Model; import org.linqs.psl.model.atom.GroundAtom; +import org.linqs.psl.model.atom.ObservedAtom; +import org.linqs.psl.model.atom.RandomVariableAtom; import org.linqs.psl.model.predicate.StandardPredicate; import org.linqs.psl.model.rule.GroundRule; import org.linqs.psl.model.rule.Rule; @@ -40,6 +45,7 @@ import org.linqs.psl.model.term.Constant; import org.linqs.psl.parser.ModelLoader; import org.linqs.psl.parser.CommandLineLoader; +import org.linqs.psl.util.ModelDataCollector; import org.linqs.psl.util.Reflection; import org.linqs.psl.util.StringUtils; import org.linqs.psl.util.Version; @@ -63,7 +69,6 @@ import java.nio.file.Paths; import java.util.List; import java.util.Map; - import java.util.Set; import java.util.regex.Pattern; @@ -211,6 +216,10 @@ private Database runInference(Model model, DataStore dataStore, Set closedPredicates) { + Set openPredicates = dataStore.getRegisteredPredicates(); + openPredicates.removeAll(closedPredicates); + + // Create database. + Partition targetPartition = dataStore.getPartition(PARTITION_NAME_TARGET); + Partition observationsPartition = dataStore.getPartition(PARTITION_NAME_OBSERVATIONS); + Partition truthPartition = dataStore.getPartition(PARTITION_NAME_LABELS); + + boolean closePredictionDB = false; + if (predictionDatabase == null) { + closePredictionDB = true; + predictionDatabase = dataStore.getDatabase(targetPartition, closedPredicates, observationsPartition); + } + + Database truthDatabase = dataStore.getDatabase(truthPartition, dataStore.getRegisteredPredicates()); + + // Create TrainingMap between predictions and truth. + PersistedAtomManager atomManager = new PersistedAtomManager(predictionDatabase, !closePredictionDB); + TrainingMap trainingMap = new TrainingMap(atomManager, truthDatabase); + + // Collect prediction and truth values for each target. + for (Map.Entry entry : trainingMap.getLabelMap().entrySet()) { + ModelDataCollector.addTruth(entry.getKey(), entry.getValue().getValue()); + } + + if (closePredictionDB) { + predictionDatabase.close(); + } + truthDatabase.close(); + + ModelDataCollector.dissatisfactionPerGroundRule(inferenceApplication.getGroundRuleStore()); + } + private void outputResults(Database database, DataStore dataStore, Set closedPredicates) { // Set of open predicates Set openPredicates = dataStore.getRegisteredPredicates(); @@ -396,6 +440,11 @@ private void run() { // Load model Model model = loadModel(dataStore); + if (parsedOptions.hasOption(CommandLineLoader.OPTION_MODEL_DATA_COLLECTION)) { + Options.CLI_MODEL_DATA_COLLECTION.set(true); + ModelDataCollector.setOutputPath(parsedOptions.getOptionValue(CommandLineLoader.OPTION_MODEL_DATA_COLLECTION)); + } + // Inference Database evalDB = null; if (parsedOptions.hasOption(CommandLineLoader.OPERATION_INFER)) { diff --git a/psl-cli/src/test/java/org/linqs/psl/cli/SimpleAcquaintancesTest.java b/psl-cli/src/test/java/org/linqs/psl/cli/SimpleAcquaintancesTest.java index e790ecaa5..fc1e922be 100644 --- a/psl-cli/src/test/java/org/linqs/psl/cli/SimpleAcquaintancesTest.java +++ b/psl-cli/src/test/java/org/linqs/psl/cli/SimpleAcquaintancesTest.java @@ -30,6 +30,9 @@ import org.junit.Test; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; import java.nio.file.Paths; import java.util.Arrays; import java.util.List; @@ -153,6 +156,27 @@ public void testErrorObservedTargets() { } } + @Test + public void testVisualization() { + String modelPath = Paths.get(baseModelsDir, "simple-acquaintances.psl").toString(); + String dataPath = Paths.get(baseDataDir, "simple-acquaintances", "base.data").toString(); + + Path tempOutputFile = Paths.get(System.getProperty("java.io.tmpdir"), "model-data.gz"); + + List additionalArgs = Arrays.asList( + "--" + CommandLineLoader.OPTION_MODEL_DATA_COLLECTION_LONG, + tempOutputFile.toString() + ); + + run(modelPath, dataPath, additionalArgs); + + try { + Files.delete(tempOutputFile); + } catch (IOException ex) { + // Not expected. + } + } + // Not an actual similarity. public static class SimNameExternalFunction implements ExternalFunction { @Override diff --git a/psl-core/src/main/java/org/linqs/psl/config/Options.java b/psl-core/src/main/java/org/linqs/psl/config/Options.java index f7bab52d5..24d3c9148 100644 --- a/psl-core/src/main/java/org/linqs/psl/config/Options.java +++ b/psl-core/src/main/java/org/linqs/psl/config/Options.java @@ -214,6 +214,12 @@ public class Options { Option.FLAG_POSITIVE ); + public static final Option CLI_MODEL_DATA_COLLECTION = new Option( + "cli.mode.data.collection", + false, + "Include visualization data collection." + ); + public static final Option EVAL_CLOSE_TRUTH = new Option( "eval.closetruth", false, diff --git a/psl-core/src/main/java/org/linqs/psl/model/rule/arithmetic/AbstractArithmeticRule.java b/psl-core/src/main/java/org/linqs/psl/model/rule/arithmetic/AbstractArithmeticRule.java index 4aaed86d8..569a7675e 100644 --- a/psl-core/src/main/java/org/linqs/psl/model/rule/arithmetic/AbstractArithmeticRule.java +++ b/psl-core/src/main/java/org/linqs/psl/model/rule/arithmetic/AbstractArithmeticRule.java @@ -17,6 +17,7 @@ */ package org.linqs.psl.model.rule.arithmetic; +import org.linqs.psl.config.Options; import org.linqs.psl.database.DatabaseQuery; import org.linqs.psl.database.ResultList; import org.linqs.psl.database.atom.AtomManager; @@ -55,6 +56,7 @@ import org.linqs.psl.model.term.Variable; import org.linqs.psl.model.term.VariableTypeMap; import org.linqs.psl.reasoner.function.FunctionComparator; +import org.linqs.psl.util.ModelDataCollector; import org.linqs.psl.util.Parallel; import com.healthmarketscience.sqlbuilder.BinaryCondition; @@ -280,6 +282,13 @@ public void ground(Constant[] constants, Map variableMap, Ato } else { groundForSummation(constants, variableMap, atomManager, results); } + + Boolean collectDataOption = (Boolean)Options.CLI_MODEL_DATA_COLLECTION.getUnlogged(); + if (collectDataOption != null && collectDataOption.booleanValue()) { + for (GroundRule groundRule : results) { + ModelDataCollector.addGroundRule(this, groundRule, variableMap, constants, atomManager); + } + } } private void groundForNonSummation(Constant[] constants, Map variableMap, AtomManager atomManager, @@ -336,8 +345,18 @@ private int groundAllNonSummationRule(AtomManager atomManager, GroundRuleStore g ResultList results = atomManager.executeQuery(new DatabaseQuery(expression.getQueryFormula(), false)); Map variableMap = results.getVariableMap(); + int priorResourcesSize = resources.groundRules.size(); for (int groundingIndex = 0; groundingIndex < results.size(); groundingIndex++) { groundSingleNonSummationRule(results.get(groundingIndex), variableMap, atomManager, resources); + int postGroundingResourcesSize = resources.groundRules.size(); + // Checking the size of the resources allows us to verify if a grounding occured or not. + if (resources.collectData) { + if (postGroundingResourcesSize != priorResourcesSize) { + GroundRule groundRule = resources.groundRules.get(resources.groundRules.size()-1); + ModelDataCollector.addGroundRule(this, groundRule, variableMap, results.get(groundingIndex), atomManager); + } + } + priorResourcesSize = resources.groundRules.size(); } int count = resources.groundRules.size(); @@ -410,8 +429,18 @@ private int groundAllSummationRule(AtomManager atomManager, GroundRuleStore grou ResultList results = database.executeQuery(rawQuery); Map variableMap = results.getVariableMap(); + int priorResourcesSize = resources.groundRules.size(); for (int groundingIndex = 0; groundingIndex < results.size(); groundingIndex++) { groundSingleSummationRule(results.get(groundingIndex), variableMap, atomManager, resources); + int postGroundingResourcesSize = resources.groundRules.size(); + // Checking the size of the resources allows us to verify if a grounding occured or not. + if (resources.collectData) { + if (postGroundingResourcesSize != priorResourcesSize) { + GroundRule groundRule = resources.groundRules.get(resources.groundRules.size()-1); + ModelDataCollector.addGroundRule(this, groundRule, variableMap, results.get(groundingIndex), atomManager); + } + } + priorResourcesSize = resources.groundRules.size(); } int count = resources.groundRules.size(); @@ -949,6 +978,8 @@ private static class GroundingResources { // Atoms that cause trouble for the atom manager. public Set accessExceptionAtoms; + public boolean collectData; + // Shared resources. public List queryAtoms; @@ -981,6 +1012,9 @@ private static class GroundingResources { public GroundingResources() { groundRules = new ArrayList(); accessExceptionAtoms = new HashSet(4); + + Boolean collectDataOption = (Boolean)Options.CLI_MODEL_DATA_COLLECTION.getUnlogged(); + collectData = (collectDataOption != null && collectDataOption.booleanValue()); } public void parseExpression(ArithmeticRuleExpression expression, boolean computeCoefficients) { diff --git a/psl-core/src/main/java/org/linqs/psl/model/rule/logical/AbstractLogicalRule.java b/psl-core/src/main/java/org/linqs/psl/model/rule/logical/AbstractLogicalRule.java index 9508cbef1..12cd4ca98 100644 --- a/psl-core/src/main/java/org/linqs/psl/model/rule/logical/AbstractLogicalRule.java +++ b/psl-core/src/main/java/org/linqs/psl/model/rule/logical/AbstractLogicalRule.java @@ -17,6 +17,7 @@ */ package org.linqs.psl.model.rule.logical; +import org.linqs.psl.config.Options; import org.linqs.psl.database.DatabaseQuery; import org.linqs.psl.database.QueryResultIterable; import org.linqs.psl.database.atom.AtomManager; @@ -41,6 +42,7 @@ import org.linqs.psl.reasoner.function.GeneralFunction; import org.linqs.psl.util.HashCode; import org.linqs.psl.util.MathUtils; +import org.linqs.psl.util.ModelDataCollector; import org.linqs.psl.util.Parallel; import org.linqs.psl.util.StringUtils; @@ -164,7 +166,12 @@ private GroundRule ground(Constant[] constants, Map variableM } GroundingResources resources = (GroundingResources)Parallel.getThreadObject(groundingResourcesKey); - return groundInternal(constants, variableMap, atomManager, resources); + GroundRule groundRule = groundInternal(constants, variableMap, atomManager, resources); + if (groundRule != null && resources.collectData) { + ModelDataCollector.addGroundRule(this, groundRule, variableMap, constants, atomManager); + } + + return groundRule; } public int groundAll(QueryResultIterable groundVariables, AtomManager atomManager, GroundRuleStore groundRuleStore) { @@ -339,6 +346,8 @@ private static class GroundingResources { // Atoms that cause trouble for the atom manager. public Set accessExceptionAtoms; + public boolean collectData; + // Allocate up-front some buffers for grounding QueryAtoms into. public Constant[][] positiveAtomArgs; public Constant[][] negativeAtomArgs; @@ -348,6 +357,9 @@ public GroundingResources(DNFClause negatedDNF) { negativeAtoms = new ArrayList(4); accessExceptionAtoms = new HashSet(4); + Boolean collectDataOption = (Boolean)Options.CLI_MODEL_DATA_COLLECTION.getUnlogged(); + collectData = (collectDataOption != null && collectDataOption.booleanValue()); + int numLiterals = negatedDNF.getPosLiterals().size() + negatedDNF.getNegLiterals().size(); positiveAtomArgs = new Constant[negatedDNF.getPosLiterals().size()][]; diff --git a/psl-core/src/main/java/org/linqs/psl/util/ModelDataCollector.java b/psl-core/src/main/java/org/linqs/psl/util/ModelDataCollector.java new file mode 100644 index 000000000..ff88b121b --- /dev/null +++ b/psl-core/src/main/java/org/linqs/psl/util/ModelDataCollector.java @@ -0,0 +1,343 @@ +package org.linqs.psl.util; + +import org.linqs.psl.database.atom.AtomManager; +import org.linqs.psl.grounding.GroundRuleStore; +import org.linqs.psl.model.atom.Atom; +import org.linqs.psl.model.atom.GroundAtom; +import org.linqs.psl.model.atom.QueryAtom; +import org.linqs.psl.model.formula.AbstractBranchFormula; +import org.linqs.psl.model.formula.Conjunction; +import org.linqs.psl.model.formula.Disjunction; +import org.linqs.psl.model.formula.Formula; +import org.linqs.psl.model.formula.Implication; +import org.linqs.psl.model.formula.Negation; +import org.linqs.psl.model.predicate.Predicate; +import org.linqs.psl.model.predicate.StandardPredicate; +import org.linqs.psl.model.rule.AbstractRule; +import org.linqs.psl.model.rule.arithmetic.AbstractGroundArithmeticRule; +import org.linqs.psl.model.rule.GroundRule; +import org.linqs.psl.model.rule.logical.AbstractLogicalRule; +import org.linqs.psl.model.rule.Rule; +import org.linqs.psl.model.rule.UnweightedGroundRule; +import org.linqs.psl.model.rule.WeightedGroundRule; +import org.linqs.psl.model.rule.WeightedRule; +import org.linqs.psl.model.term.Constant; +import org.linqs.psl.model.term.Term; +import org.linqs.psl.model.term.Variable; + +import org.json.JSONObject; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.FilterOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.zip.GZIPOutputStream; + +public class ModelDataCollector { + private static final Logger log = LoggerFactory.getLogger(ModelDataCollector.class); + + private static Runtime runtime = null; + private static ModelData modelData = null; + + private static String outputPath = null; + + static { + init(); + } + + private ModelDataCollector() {} + + private static synchronized void init() { + if (runtime != null) { + return; + } + + modelData = new ModelData(); + runtime = Runtime.getRuntime(); + runtime.addShutdownHook(new ShutdownHook()); + } + + public static void outputJSON() { + + try { + if (outputPath == null) { + throw new RuntimeException(); + } + + GZIPOutputStream stream = new GZIPOutputStream(new PrintStream(outputPath)); + writeToStream(stream); + stream.close(); + } catch (IOException ex) { + if (outputPath == null) { + throw new RuntimeException("Path not specified for output file."); + } else { + throw new RuntimeException("Could not write to path: " + outputPath, ex); + } + } + } + + /** + * Write to stream with JSON formatting. + */ + private static void writeToStream(FilterOutputStream stream) throws IOException { + // Write each map as a JSON object, each JSON object is comma delimited. + stream.write('{'); + + stream.write("\"truthMap\":".getBytes()); + writeMap(stream, modelData.truthMap); + stream.write(','); + + stream.write("\"rules\":".getBytes()); + writeMap(stream, modelData.rules); + stream.write(','); + + stream.write("\"groundRules\":".getBytes()); + writeMap(stream, modelData.groundRules); + stream.write(','); + + stream.write("\"groundAtoms\":".getBytes()); + writeMap(stream, modelData.groundAtoms); + + stream.write('}'); + } + + /** + * Write map to stream with JSON formatting. + * Assumption for optimizing write buffer: Map values use small amount of memory + */ + private static void writeMap(FilterOutputStream stream, Object map) throws IOException { + stream.write('{'); + + @SuppressWarnings("unchecked") + Map stringObjMap = (Map)map; + Iterator> iterator = stringObjMap.entrySet().iterator(); + while (iterator.hasNext()) { + Map.Entry entry = iterator.next(); + stream.write(("\"" + entry.getKey() + "\":").getBytes()); + + // Values of the map will either be a Float or Map. + if (entry.getValue() instanceof Float) { + stream.write(entry.getValue().toString().getBytes()); + } else { + @SuppressWarnings("unchecked") + Map data = (Map)entry.getValue(); + JSONObject jsonObject = new JSONObject(data); + stream.write(jsonObject.toString().getBytes()); + } + + if (iterator.hasNext()) { + stream.write(','); + } + } + + stream.write('}'); + } + + private static class ShutdownHook extends Thread { + @Override + public void run() { + outputJSON(); + } + } + + public static class ModelData { + public Map truthMap; + public Map> rules; + public Map> groundRules; + public Map> groundAtoms; + + public ModelData() { + truthMap = new HashMap(); + rules = new HashMap>(); + groundRules = new HashMap>(); + groundAtoms = new HashMap>(); + } + } + + public static void setOutputPath(String path) { + outputPath = path; + } + + // Takes in a prediction truth pair and adds it to the Truth Map. + public static void addTruth(GroundAtom target, float truthVal ) { + String groundAtomID = Integer.toString(System.identityHashCode(target)); + modelData.truthMap.put(groundAtomID, truthVal); + } + + // Adds a ground rules dissatisfaction / infeasibility to the groundRules HashMap. + public static void dissatisfactionPerGroundRule(GroundRuleStore groundRuleStore) { + for (GroundRule groundRule : groundRuleStore.getGroundRules()) { + String strGroundRuleId = Integer.toString(System.identityHashCode(groundRule)); + Map groundRuleObj = modelData.groundRules.get(strGroundRuleId); + if (groundRule instanceof WeightedGroundRule) { + WeightedGroundRule weightedGroundRule = (WeightedGroundRule) groundRule; + groundRuleObj.put("dissatisfaction", weightedGroundRule.getIncompatibility()); + } else { + UnweightedGroundRule unweightedGroundRule = (UnweightedGroundRule) groundRule; + groundRuleObj.put("dissatisfaction", unweightedGroundRule.getInfeasibility()); + } + } + } + + // Decorates an entry in a Formula to have constants or negations marked. + public static Object decorateFormula(GroundAtom groundAtom, GroundRule groundRule, boolean negation) { + if (negation) { + return new Integer[] {System.identityHashCode(groundAtom), 1}; + } + + return System.identityHashCode(groundAtom); + } + + // Parses through an atom object, gets ground atoms from AtomManager via predicate and arguments. + public static GroundAtom parseAtom (Formula formula, Map varConstMap, AtomManager atomManager) { + Atom atom = (Atom)formula; + Predicate predicate = atom.getPredicate(); + Term[] arguments = atom.getArguments(); + + Constant[] constants = new Constant[2]; // Atoms will have 2 constants each, so can used a fixed array size to pass into atomManager.getAtom() + int i = 0; + for (Term t : arguments){ + if (!(t instanceof Variable)) { + continue; + } + constants[i] = varConstMap.get(t); + i++; + } + if (predicate instanceof StandardPredicate) { + return atomManager.getAtom((StandardPredicate)predicate, constants); + } + + return null; + } + + // Parses through a formula object, and creates the needed object for modelData object consumption. + public static List parseFormula(Formula formula, Map varConstMap, GroundRule groundRule, boolean negation, AtomManager atomManager) { + ArrayList groundAtoms = new ArrayList(); + if (formula instanceof QueryAtom){ + GroundAtom groundedAtom = parseAtom(formula, varConstMap, atomManager); + if (groundedAtom != null) { + Object decoratedFormula = decorateFormula(groundedAtom, groundRule, negation); + if (decoratedFormula != null) { + groundAtoms.add(decoratedFormula); + } + } + } else { + AbstractBranchFormula branchFormula = (AbstractBranchFormula) formula; + for (int i = 0; i < branchFormula.length(); i++){ + GroundAtom groundedAtom = parseAtom(branchFormula.get(i), varConstMap, atomManager); + if (groundedAtom != null) { + Object decoratedFormula = decorateFormula(groundedAtom, groundRule, negation); + if (decoratedFormula != null) { + groundAtoms.add(decoratedFormula); + } + } + } + } + return groundAtoms; + } + + public static synchronized void addGroundRule(AbstractRule parentRule, + GroundRule groundRule, Map variableMap, Constant[] constantsList, AtomManager atomManager) { + if (groundRule == null) { + return; + } + + // Create the variable constant map used for replacement. + Map varConstMap = new HashMap(); + for (Map.Entry entry : variableMap.entrySet()) { + varConstMap.put(entry.getKey(), constantsList[entry.getValue()]); + } + + // Captures the lhs and rhs of a ground rule in order to add to the modelData object. + String groundRuleString; + List lhs = new ArrayList(); + List rhs = new ArrayList(); + String operator = ""; + if (parentRule instanceof AbstractLogicalRule) { + AbstractLogicalRule abstractLogicalParent = (AbstractLogicalRule) parentRule; + Formula formula = abstractLogicalParent.getFormula(); + boolean negationFlag = false; + if (formula instanceof Implication) { + operator = ">>"; + Implication implication = (Implication) formula; + Formula body = implication.getBody(); + Formula head = implication.getHead(); + if (body instanceof Negation) { + Negation negation = (Negation) body; + body = negation.getFormula(); + negationFlag = true; + } + lhs = parseFormula(body, varConstMap, groundRule, negationFlag, atomManager); + + if (head instanceof Negation) { + Negation negation = (Negation) head; + head = negation.getFormula(); + negationFlag = true; + } + rhs = parseFormula(head, varConstMap, groundRule, negationFlag, atomManager); + } else if (formula instanceof Conjunction || formula instanceof Disjunction){ + lhs = parseFormula(formula, varConstMap, groundRule, negationFlag, atomManager); + } else if (formula instanceof Negation){ + Negation negation = (Negation) formula; + Formula negationFormula = negation.getFormula(); + negationFlag = true; + lhs = parseFormula(negationFormula, varConstMap, groundRule, negationFlag, atomManager); + } + } else { + AbstractGroundArithmeticRule abstractArithmetic = (AbstractGroundArithmeticRule) groundRule; + GroundAtom[] orderedAtoms = abstractArithmetic.getOrderedAtoms(); + float[] coefficients = abstractArithmetic.getCoefficients(); + + for (int i = 0; i < orderedAtoms.length; i++) { + if (i < coefficients.length) { + Object[] atomObject = {System.identityHashCode(orderedAtoms[i]), coefficients[i]}; + lhs.add(atomObject); + } else { + lhs.add(System.identityHashCode(orderedAtoms[i])); + } + } + operator = abstractArithmetic.getComparator().toString(); + } + + // Adds a groundAtom element to the modelData object. + ArrayList atomHashList = new ArrayList(); + HashSet atomSet = new HashSet(groundRule.getAtoms()); + int atomCount = 0; + for (GroundAtom groundAtom : atomSet) { + atomHashList.add(System.identityHashCode(groundAtom)); + Map groundAtomElement = new HashMap(); + groundAtomElement.put("text", groundAtom.toString()); + groundAtomElement.put("prediction", groundAtom.getValue()); + modelData.groundAtoms.put(Integer.toString(System.identityHashCode(groundAtom)), groundAtomElement); + atomCount++; + } + + // Adds a rule element to the modelData object. + String ruleStringID = Integer.toString(System.identityHashCode(parentRule)); + Map rulesElementItem = new HashMap(); + rulesElementItem.put("text", parentRule.getName()); + if (parentRule instanceof WeightedRule) { + WeightedRule weightedParentRule = (WeightedRule) parentRule; + rulesElementItem.put("weighted", weightedParentRule.getWeight()); + } else { + rulesElementItem.put("weighted", null); + } + modelData.rules.put(ruleStringID, rulesElementItem); + + // Adds a groundRule element to the modelData object. + Map groundRulesElement = new HashMap(); + groundRulesElement.put("ruleID", Integer.parseInt(ruleStringID)); + groundRulesElement.put("lhs", lhs); + groundRulesElement.put("rhs", rhs); + groundRulesElement.put("operator", operator); + String groundRuleStringID = Integer.toString(System.identityHashCode(groundRule)); + modelData.groundRules.put(groundRuleStringID, groundRulesElement); + } +} diff --git a/psl-parser/src/main/java/org/linqs/psl/parser/CommandLineLoader.java b/psl-parser/src/main/java/org/linqs/psl/parser/CommandLineLoader.java index 0ee29bd07..a978080f2 100644 --- a/psl-parser/src/main/java/org/linqs/psl/parser/CommandLineLoader.java +++ b/psl-parser/src/main/java/org/linqs/psl/parser/CommandLineLoader.java @@ -64,6 +64,8 @@ public class CommandLineLoader { public static final String OPTION_DB_POSTGRESQL_NAME = "postgres"; public static final String OPTION_EVAL = "e"; public static final String OPTION_EVAL_LONG = "eval"; + public static final String OPTION_MODEL_DATA_COLLECTION = "viz"; + public static final String OPTION_MODEL_DATA_COLLECTION_LONG = "visualization"; public static final String OPTION_INT_IDS = "int"; public static final String OPTION_INT_IDS_LONG = "int-ids"; public static final String OPTION_LOG4J = "4j"; @@ -263,6 +265,15 @@ private static Options setupOptions() { .argName("evaluator ...") .build()); + newOptions.addOption(Option.builder(OPTION_MODEL_DATA_COLLECTION) + .longOpt(OPTION_MODEL_DATA_COLLECTION_LONG) + .desc("Gather data for creating a visualization of a given PSL run." + + " When a path is specified, the visualization will be output there.") + .hasArg() + .argName("path") + .optionalArg(false) + .build()); + newOptions.addOption(Option.builder(OPTION_INT_IDS) .longOpt(OPTION_INT_IDS_LONG) .desc("Use integer identifiers (UniqueIntID) instead of string identifiers (UniqueStringID).")