Hello all,
I've searched through the examples and so far have only seen how to use the one-hot encoder for model fitting or for the evaluator, but I can't figure out how to do this for the predict call. For example, we see the one-hot encoder used as an input to:

1. The trainer:

    RF_MODEL = trainer.fit(
        <ignite>,
        <trainingcache>,        // this has the category column before one-hot
        split.getTrainFilter(),
        <one-hot-encoder>       // this does one-hot inside the model - how do I
                                // get the cache with the additional columns?
    );

2. Or also here, the evaluator:

    RegressionMetricValues regMetrics = Evaluator.evaluateRegression(
        <trainingcache>,
        split.getTestFilter(),
        <rf_model>,
        <one-hot-encoder>
    );

But rfmodel.predict(Vector features) requires that the original Vector with categorical columns already be converted into all doubles. What is the best way to do this intermediate step?

--
Sent from: http://apache-ignite-developers.2346864.n4.nabble.com/
Hi, Ken.
Currently, the preprocessors are not independent operations yet (you cannot train an OHE separately and re-use the OHE for prediction on its own). But if you have a trained preprocessor (encoder) filled with encoding values after the training process, you can re-use it in the prediction phase as in the example below.

The sequence of calls during training is: vectorizer -> encoder -> trainer.

The *encoderPreprocessor* contains the statistics learnt during the training phase and has a special method, *apply*, that takes the key and value from a cache entry. Here, in the loop, I call the apply method and it applies the vectorizer and the encoder (the encoder contains a link to the vectorizer internally) to the Object[] in each cache entry value. If you run this example, you get the same accuracy in both evaluation sections.

The initial example is located here:
https://github.com/apache/ignite/blob/master/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/encoding/EncoderExample.java

Also, the trainer could be changed to Random Forest, for example. Hope it solves your problem.

public class EncoderExample {
    /** Run example. */
    public static void main(String[] args) throws Exception {
        System.out.println();
        System.out.println(">>> Train Decision Tree model on mushrooms.csv dataset.");

        try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
            try {
                IgniteCache<Integer, Object[]> dataCache = new SandboxMLCache(ignite)
                    .fillObjectCacheWithDoubleLabels(MLSandboxDatasets.MUSHROOMS);

                final Vectorizer<Integer, Object[], Integer, Object> vectorizer =
                    new ObjectArrayVectorizer<Integer>(1, 2, 3).labeled(0);

                Preprocessor<Integer, Object[]> encoderPreprocessor = new EncoderTrainer<Integer, Object[]>()
                    .withEncoderType(EncoderType.STRING_ENCODER)
                    .withEncodedFeature(0)
                    .withEncodedFeature(1)
                    .withEncodedFeature(2)
                    .fit(ignite, dataCache, vectorizer);

                DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(5, 0);

                // Train decision tree model.
                DecisionTreeNode mdl = trainer.fit(ignite, dataCache, encoderPreprocessor);

                System.out.println("\n>>> Trained model: " + mdl);

                int tp = 0;
                int total = 0;

                try (QueryCursor<Cache.Entry<Integer, Object[]>> observations = dataCache.query(new ScanQuery<>())) {
                    for (Cache.Entry<Integer, Object[]> observation : observations) {
                        Object[] val = observation.getValue();
                        double groundTruth = (double)val[0];

                        // Apply the trained vectorizer + encoder to the raw row,
                        // producing an all-double feature vector for predict().
                        LabeledVector apply = encoderPreprocessor.apply(observation.getKey(), val);
                        double prediction = mdl.predict(apply.features());

                        total++;
                        if (prediction == groundTruth)
                            tp++;

                        System.out.printf(">>> | %.4f\t\t\t| %.4f\t\t|\n", prediction, groundTruth);
                    }

                    System.out.println(">>> ---------------------------------");
                    System.out.println("Manual accuracy is: " + ((double)tp) / total);
                    System.out.println(">>> Decision Tree algorithm over cached dataset usage example completed.");
                }

                double accuracy = Evaluator.evaluate(dataCache, mdl, encoderPreprocessor, new Accuracy<>());

                System.out.println("\n>>> Accuracy " + accuracy);
                System.out.println("\n>>> Test Error " + (1 - accuracy));
                System.out.println(">>> Train Decision Tree model on mushrooms.csv dataset.");
            }
            catch (FileNotFoundException e) {
                e.printStackTrace();
            }
        }
        finally {
            System.out.flush();
        }
    }
}
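To make the idea behind encoderPreprocessor.apply concrete: a trained string encoder essentially holds a learned category-to-double mapping that must be applied to each raw row before predict(). Below is a self-contained sketch of that mechanism only — the class, method, and mapping here are hypothetical illustrations, not the Ignite API:

```java
import java.util.HashMap;
import java.util.Map;

/**
 * Hypothetical sketch (not Ignite API): applying a learned
 * category-to-double mapping to a raw row before prediction.
 */
public class ManualEncodingSketch {
    /** Encode a raw row where column 0 is categorical and the rest are numeric. */
    static double[] encode(Object[] raw, Map<String, Double> categoryMap) {
        double[] features = new double[raw.length];
        // Replace the category string with the double learned during fit().
        features[0] = categoryMap.get((String)raw[0]);
        for (int i = 1; i < raw.length; i++)
            features[i] = ((Number)raw[i]).doubleValue();
        return features;
    }

    public static void main(String[] args) {
        // Stand-in for the mapping a trained encoder accumulates during fit().
        Map<String, Double> colorMap = new HashMap<>();
        colorMap.put("red", 0.0);
        colorMap.put("green", 1.0);
        colorMap.put("blue", 2.0);

        // A raw observation: one categorical column, one numeric column.
        double[] features = encode(new Object[] {"green", 3.5}, colorMap);

        // features is now all doubles, ready to pass to mdl.predict(...).
        System.out.println(features[0] + " " + features[1]); // prints "1.0 3.5"
    }
}
```

In Ignite this bookkeeping is done for you: the trained encoder preprocessor already carries the learned mapping, so calling its apply method on a cache key/value is the intermediate step the question asks about.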