using System;
using System.Collections.Generic;
using Myriax.Eonfusion.API;
using Myriax.Eonfusion.API.Binding;
using Myriax.Eonfusion.API.Data;
using Myriax.Eonfusion.API.Properties;
using Myriax.Eonfusion.API.UI.Descriptors;

namespace rForest
{
    /// <summary>
    /// Eonfusion add-in that ranks the rows of a "full" vector set with an ALGLIB random
    /// decision forest. The forest is either trained from a "training" vector set (and then
    /// serialised into a table on input dataset 1) or rebuilt from a previously stored table.
    /// </summary>
    /// <remarks>
    /// NOTE(review): the text this file was recovered from had lost its generic type arguments
    /// (e.g. <c>List&gt;</c>, <c>BindData(...)</c>). They have been restored where the
    /// surrounding code determines them; every remaining guess or gap is marked NOTE(review).
    /// </remarks>
    public class AddIn : IAddIn
    {
        // Value substituted for a cell whose read raises NoDataException.
        private const double NoDataSubstitute = .5;

        /// <summary>
        /// Trains or loads the forest, writes one prediction per feature row into
        /// <c>rank</c>, and returns input dataset 1.
        /// </summary>
        /// <param name="application">Host application supplying datasets and the event log.</param>
        /// <returns><c>application.InputDatasets[1]</c>, after its feature rows have been ranked.</returns>
        /// <exception cref="InvalidOperationException">
        /// No usable parameter names, a named parameter list is missing, no training data is
        /// available when a forest must be trained, ALGLIB rejects the training set, or a
        /// stored forest expects a different number of variables.
        /// </exception>
        public IDataset Execute(IApplication application)
        {
            // An empty "RF Table" binding means there is no stored forest to reuse.
            bool trainNewForest = string.IsNullOrEmpty(BindProperty2.Value);

            inputVS inputVSG = null;
            if (!string.IsNullOrEmpty(BindProperty.Value))
            {
                inputVSG = application.InputDatasets.BindData<inputVS>(BindProperty);
            }
            outputVS outputVSG = application.InputDatasets.BindData<outputVS>(BindProperty1);

            if (trainNewForest && inputVSG == null)
            {
                throw new InvalidOperationException(
                    "Either a training vector set or an existing RF table must be supplied.");
            }

            //Form the list of input parameters.
            string[] parameterNames = SplitParameterNames(inputParametersINPUT.Value);
            int nParameters = parameterNames.Length;
            application.EventGroup.InsertInfoEvent(String.Format("nParams is {0}", nParameters));
            if (nParameters == 0)
            {
                throw new InvalidOperationException("At least one input parameter name is required.");
            }

            //Build the lists of parameters to use for the classification.
            List<IValueList<double>> inputParamListList = null;
            if (trainNewForest)
            {
                inputParamListList = CollectParameterLists(
                    parameterNames,
                    name => inputVSG.FeatureTable.GetValueList(name).AsValueList<double>(),
                    "training");
            }
            List<IValueList<double>> outputParamListList = CollectParameterLists(
                parameterNames,
                name => outputVSG.FeatureTable.GetValueList(name).AsValueList<double>(),
                "full");
            foreach (string paramName in parameterNames)
            {
                application.EventGroup.InsertInfoEvent(String.Format("params added {0}", paramName));
            }

            dforest.decisionforest df;
            if (trainNewForest)
            {
                df = TrainForest(application, inputVSG, inputParamListList);
            }
            else
            {
                //There is an existing RF so lets try to deserialise it.
                DF existingRF = application.InputDatasets.BindData<DF>(BindProperty2);
                df = LoadForest(existingRF);
                if (df.nvars != nParameters)
                {
                    throw new InvalidOperationException(String.Format(
                        "The stored forest expects {0} input parameters but {1} were supplied.",
                        df.nvars, nParameters));
                }
            }

            ClassifyRows(application, outputVSG, outputParamListList, df);

            if (trainNewForest)
            {
                StoreForest(application, df);
            }

            return application.InputDatasets[1];
        }

        /// <summary>
        /// Splits a comma separated list of names, trimming surrounding white space and
        /// dropping entries that are empty after trimming.
        /// </summary>
        private static string[] SplitParameterNames(string commaSeparated)
        {
            string[] pieces = (commaSeparated ?? string.Empty).Split(
                new char[] { ',' }, StringSplitOptions.RemoveEmptyEntries);

            List<string> names = new List<string>(pieces.Length);
            foreach (string piece in pieces)
            {
                // "att1, att2" must resolve to "att2", not " att2".
                string name = piece.Trim();
                if (name.Length > 0)
                {
                    names.Add(name);
                }
            }
            return names.ToArray();
        }

        /// <summary>
        /// Resolves one value list per parameter name, in order, failing when a lookup yields null.
        /// </summary>
        /// <remarks>
        /// Silently skipping a missing list (as the code previously did) leaves fewer lists than
        /// parameters and breaks the per-parameter indexing further on.
        /// NOTE(review): <c>IValueList&lt;double&gt;</c> is a reconstruction of a stripped generic
        /// type; confirm against the Eonfusion API. The original comment asked for an add-in
        /// specific exception type, which is not visible from this file.
        /// </remarks>
        private static List<IValueList<double>> CollectParameterLists(
            string[] parameterNames,
            Func<string, IValueList<double>> lookup,
            string tableDescription)
        {
            List<IValueList<double>> lists = new List<IValueList<double>>(parameterNames.Length);
            foreach (string paramName in parameterNames)
            {
                IValueList<double> list = lookup(paramName);
                if (list == null)
                {
                    throw new InvalidOperationException(String.Format(
                        "Parameter '{0}' was not found in the {1} feature table.",
                        paramName, tableDescription));
                }
                lists.Add(list);
            }
            return lists;
        }

        /// <summary>
        /// Builds the training matrix from the training vector set and trains a new forest.
        /// </summary>
        /// <exception cref="InvalidOperationException">
        /// There are no training rows, or ALGLIB reports a failure code.
        /// </exception>
        private dforest.decisionforest TrainForest(
            IApplication application,
            inputVS inputVSG,
            List<IValueList<double>> inputParamListList)
        {
            int nParameters = inputParamListList.Count;
            int tableRowCount = inputVSG.FeatureTable.Count;

            // Row 0 is excluded here exactly as it is when classifying, so the usable rows are
            // 1..Count-1. NOTE(review): presumably row 0 is a reserved row - confirm.
            int sampleCount = tableRowCount - 1;
            if (sampleCount < 1)
            {
                throw new InvalidOperationException("The training vector set contains no rows.");
            }

            // ALGLIB layout: the first nParameters columns hold the independent variables and
            // the final column holds the class number / dependent variable.
            double[,] array = new double[sampleCount, nParameters + 1];
            for (int i = 1; i < tableRowCount; i++)
            {
                int sample = i - 1;
                for (int param = 0; param < nParameters; param++)
                {
                    try
                    {
                        array[sample, param] = inputParamListList[param][i];
                    }
                    catch (NoDataException)
                    {
                        array[sample, param] = NoDataSubstitute;
                    }
                }
                array[sample, nParameters] = inputVSG.FeatureTable[i].classification;
            }
            application.EventGroup.InsertInfoEvent(String.Format("The input array was built."));

            //Create the random forest.
            int info = 0;
            dforest.decisionforest df = new dforest.decisionforest();
            dforest.dfreport dfreport = new dforest.dfreport();
            dforest.dfbuildrandomdecisionforest(ref array, sampleCount, nParameters, nClasses, nTrees, rValue, ref info, ref df, ref dfreport);
            if (info != 1)
            {
                // ALGLIB: -2 = class number outside [0, nClasses-1], -1 = invalid arguments.
                throw new InvalidOperationException(String.Format(
                    "Building the random forest failed (ALGLIB return code {0}).", info));
            }
            application.EventGroup.InsertInfoEvent(String.Format("The tree was built."));

            //Create report statistics regarding the random forest.
            application.EventGroup.InsertInfoEvent(String.Format("OOB RMS error is: {0}, OOB average relative error is: {1}", dfreport.oobrmserror, dfreport.oobavgrelerror));
            return df;
        }

        /// <summary>
        /// Rebuilds a forest from the serialised values held in <paramref name="existingRF"/>.
        /// </summary>
        private static dforest.decisionforest LoadForest(DF existingRF)
        {
            // NOTE(review): feature tables are read from row 1, but this table is read from
            // row 0; confirm the stored model table has no reserved first row.
            double[] serialModel = new double[existingRF.Count];
            for (int i = 0; i < existingRF.Count; i++)
            {
                serialModel[i] = existingRF[i].dfModel;
            }

            dforest.decisionforest df = new dforest.decisionforest();
            dforest.dfunserialize(ref serialModel, ref df);
            return df;
        }

        /// <summary>
        /// Runs every feature row except row 0 through the forest and stores the first output
        /// component in the row's <c>rank</c>, reporting progress as it goes.
        /// </summary>
        private static void ClassifyRows(
            IApplication application,
            outputVS outputVSG,
            List<IValueList<double>> outputParamListList,
            dforest.decisionforest df)
        {
            //Create data structures to hold the results of the classification of the test dataset.
            int nParameters = outputParamListList.Count;
            double[] outputAttributesArray = new double[nParameters];
            // dfprocess does not allocate its output; it needs at least nclasses elements.
            double[] result = new double[Math.Max(1, df.nclasses)];
            int count = 0;
            int rowcount = outputVSG.FeatureTable.Count;
            var progEvent = application.EventGroup.InsertProgressEvent("Assigning probabilities");

            foreach (var row in outputVSG.FeatureTable.Rows)
            {
                if (row.RowIndex == 0)
                {
                    continue;
                }

                for (int param = 0; param < nParameters; param++)
                {
                    try
                    {
                        outputAttributesArray[param] = outputParamListList[param][row.RowIndex];
                    }
                    catch (NoDataException)
                    {
                        outputAttributesArray[param] = NoDataSubstitute;
                    }
                }

                dforest.dfprocess(ref df, ref outputAttributesArray, ref result);
                row.rank = result[0];
                count++;
                progEvent.UpdateProgressPercent(count, rowcount);
            }
        }

        /// <summary>
        /// Serialises the forest into a new "forest model" table on input dataset 1, one value per row.
        /// </summary>
        private static void StoreForest(IApplication application, dforest.decisionforest df)
        {
            // NOTE(review): CreateTable / CreateValueList are reproduced exactly as recovered;
            // they appear to have lost generic type arguments (a row type exposing dfModel, and
            // presumably double) that cannot be determined from this file. Restore before building.
            var dfmodelStore = application.InputDatasets[1].CreateTable("forest model");
            var rfList = dfmodelStore.CreateValueList("forestBit");
            dfmodelStore.SetListBinding(v => v.dfModel, rfList);

            int length = 0;
            double[] dfSerial = new double[1];
            dforest.dfserialize(ref df, ref dfSerial, ref length);

            // Only the first 'length' entries are part of the serialised model.
            for (int i = 0; i < length; i++)
            {
                var rowSerial = dfmodelStore.AllocateOne();
                rowSerial.dfModel = dfSerial[i];
            }
        }

        #region AddIn properties

        BindingProperty BindProperty = new BindingProperty
        {
            PropertyName = "Training VS",
            PropertyDescription = "Vector set containing the training data.",
            DefaultValue = "Training",
            InputSocketIndex = 0
        };

        BindingProperty BindProperty1 = new BindingProperty
        {
            PropertyName = "Full VS",
            PropertyDescription = "Vector set containing all data to be compared.",
            DefaultValue = "Full",
            InputSocketIndex = 1
        };

        StringProperty inputParametersINPUT = new StringProperty
        {
            PropertyName = "input parameters",
            PropertyDescription = "A comma separated list of input parameters",
            DefaultValue = "att1, att2..."
        };

        IntProperty nClasses = new IntProperty
        {
            PropertyName = "Number of classes",
            PropertyDescription = "A number representing the number of classes.",
            DefaultValue = 1
        };

        IntProperty nTrees = new IntProperty
        {
            PropertyName = "Number of trees",
            PropertyDescription = "A number representing the number of trees.",
            DefaultValue = 50
        };

        DoubleProperty rValue = new DoubleProperty
        {
            PropertyName = "R value",
            PropertyDescription = "A number between 0 and 1 representing the proportion of the data to use.",
            DefaultValue = 0.66
        };

        BindingProperty BindProperty2 = new BindingProperty
        {
            PropertyName = "RF Table (optional)",
            PropertyDescription = "A table which holds the previous RF.",
            DefaultValue = "",
            InputSocketIndex = 0
        };

        //nClassifications
        //nTrees
        //R (0 - 1) how much of the dataset to use.

        #endregion

        #region AddIn description

        public int InputSocketCount
        {
            get { return 2; }
        }

        public string Name
        {
            get { return "rForest"; }
        }

        public string Category
        {
            get { return "add-in UGM"; }
        }

        public string Description
        {
            get { return "This add-in contains code from "; }
        }

        public string Author
        {
            get { return "Alex Leith"; }
        }

        // NOTE(review): the element type was stripped from the recovered text; IDescriptor is
        // assumed from the commented-out descriptor template below - confirm against IAddIn.
        public System.Collections.Generic.IEnumerable<IDescriptor> Descriptors
        {
            get
            {
                // To add additional, optional authoring information, remove the "return null" statement and uncomment the
                // required fields. The descriptors can be reordered and/or repeated as desired, and formatted using RTF markup.
                return null;
                //yield return new DateCreatedDescriptor("(insert date created here)");
                //yield return new DateModifiedDescriptor("(insert date modified here)");
                //yield return new LineBreakDescriptor();
                //yield return new OrganizationDescriptor("(insert organization here)");
                //yield return new URLDescriptor("(insert URL here)");
                //yield return new EmailDescriptor("(insert email address here)");
                //yield return new LineBreakDescriptor();
                //yield return new LicenseDescriptor("(insert license details here)");
                //yield return new CopyrightDescriptor("(insert copyright notice here)");
                //yield return new ReferencesDescriptor("(insert reference here)");
                //yield return new LineBreakDescriptor();
                //yield return new UserDescriptor("(insert additional heading here)", "(insert additional descriptor text here)");
            }
        }

        #endregion
    }
}