public class AdaptiveSGDROIFinder extends MLROIConfFinder
Inherited field: confThr
Constructor | Description |
---|---|
AdaptiveSGDROIFinder() | Creates a flaw finder with 2 categories (is / isn't a flaw) and a linear combination of Laplacian and Gaussian regularization. |
AdaptiveSGDROIFinder(int numCats, org.apache.mahout.classifier.sgd.PriorFunction priorFunction) | Creates a new AdaptiveSGDROIFinder with the specified number of categories and the specified regularization function. |
AdaptiveSGDROIFinder(int numCats, org.apache.mahout.classifier.sgd.PriorFunction priorFunction, int threadCount, int poolSize) | Creates a new AdaptiveSGDROIFinder with the specified number of categories, regularization function, number of training threads, and number of SGD learners. |
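The default constructor is described as using a linear combination of Laplacian and Gaussian regularization. A Laplacian prior on a weight corresponds to an L1 penalty (sum of absolute values) and a Gaussian prior to an L2 penalty (sum of squares), so the combined penalty can be sketched as below. This is an illustration only: the class ElasticPenaltySketch and the weights l1Weight and l2Weight are hypothetical and do not reproduce Mahout's PriorFunction implementations or the values this class actually uses.

```java
// Hypothetical sketch: combined L1 (Laplacian) + L2 (Gaussian) penalty.
// Not Mahout's PriorFunction API; the weights are illustrative assumptions.
final class ElasticPenaltySketch {
    private ElasticPenaltySketch() {}

    /** Returns l1Weight * sum(|w_i|) + l2Weight * sum(w_i^2). */
    static double penalty(double[] weights, double l1Weight, double l2Weight) {
        double l1 = 0.0; // Laplacian prior term: sum of absolute values
        double l2 = 0.0; // Gaussian prior term: sum of squares
        for (double w : weights) {
            l1 += Math.abs(w);
            l2 += w * w;
        }
        return l1Weight * l1 + l2Weight * l2;
    }

    public static void main(String[] args) {
        double[] w = {1.0, -2.0, 0.5};
        System.out.println(penalty(w, 0.5, 0.25));
    }
}
```

The L1 term pushes small weights to exactly zero while the L2 term shrinks large weights smoothly; mixing them trades sparsity against stability.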
Modifier and Type | Method | Description |
---|---|---|
org.apache.mahout.math.Vector | classify(org.apache.mahout.math.DenseVector d) | |
void | close() | Shuts down the worker pool used during training and best-model determination. |
org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression | getModel() | Returns the current meta-model. |
int | getNumCategories() | Returns the number of categories for this model. |
int | getNumFeatures() | Returns the number of features for this model. |
java.util.Map<java.lang.String,java.lang.Object> | getObjectMap() | Creates a map of the important fields for the instance, suitable for serialization. |
int | getPoolSize() | Returns the number of SGD learners in the meta-model. |
org.apache.mahout.classifier.sgd.PriorFunction | getPriorFunction() | Returns the regularization function used to limit overfitting. |
long | getSerializationVersion() | Returns the current version of the serialization format. |
int | getThreadCount() | Returns the number of threads used to train the meta-model's learners. |
int | getVersion() | Returns the current class version. |
void | initCurrentVersion(java.util.Map<java.lang.String,java.lang.Object> objectMap) | Initializes an instance from a current-version object graph. |
boolean | isROI(Dataset dataset) | Predicts whether a sample contains a region of interest. |
boolean | isROI(double[] data) | Predicts whether a sample contains a region of interest. |
void | legacyRead(com.esotericsoftware.kryo.Kryo kryo, com.esotericsoftware.kryo.io.Input input) | |
void | legacyWrite(com.esotericsoftware.kryo.Kryo kryo, com.esotericsoftware.kryo.io.Output output) | |
double | negativeClass() | Returns the numeric value of the negative class of a two-category model. |
double | positiveClass() | Returns the numeric value of the positive class of a two-category model. |
void | setModel(org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression model) | Sets the current meta-model. |
void | setNumCategories(int numCategories) | Sets the number of categories for this model. |
void | setNumFeatures(int numFeatures) | Sets the number of features for this model. |
void | setPoolSize(int poolSize) | Sets the number of SGD learners in the meta-model. |
void | setPriorFunction(org.apache.mahout.classifier.sgd.PriorFunction priorFunction) | Sets the regularization function used to limit overfitting. |
void | setThreadCount(int threadCount) | Sets the number of threads to use during training. |
void | train(double[][] X, int[] y) | Trains the current model with additional samples, creating the model if required. |
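The summary lists isROI alongside positiveClass() and the inherited getConfidenceThreshold(). One plausible way such a confidence-thresholded decision could work is sketched below: the model yields a probability per category, and a sample is flagged as a region of interest when the probability of the positive category meets the threshold. This is an assumption, not the documented behavior of isROI; the class RoiDecisionSketch, the index-based positive class, and the >= comparison are all hypothetical.

```java
// Hypothetical sketch of a confidence-thresholded ROI decision.
// Assumes class probabilities are indexed by category and that the
// positive class is a valid index (e.g. 1 for a two-category model).
final class RoiDecisionSketch {
    private RoiDecisionSketch() {}

    static boolean isRoi(double[] classProbabilities, int positiveIndex, double confidenceThreshold) {
        if (positiveIndex < 0 || positiveIndex >= classProbabilities.length) {
            throw new IllegalArgumentException("positive class index out of range: " + positiveIndex);
        }
        // Flag the sample only when the model is sufficiently confident
        return classProbabilities[positiveIndex] >= confidenceThreshold;
    }

    public static void main(String[] args) {
        double[] probs = {0.3, 0.7}; // [negative, positive] for a 2-category model
        System.out.println(isRoi(probs, 1, 0.5)); // true
        System.out.println(isRoi(probs, 1, 0.9)); // false
    }
}
```

Raising the threshold trades missed regions of interest for fewer false alarms.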
Methods inherited from superclasses and interfaces:
getConfidenceThreshold, predict, predict, setConfidenceThreshold, writableWriteToBytes
init, initPreviousVersion, initUnknownVersion, load, read, save, write
predict_proba, predict_proba
Methods inherited from class java.lang.Object:
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
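The methods getVersion(), getObjectMap() and initCurrentVersion(...), together with the inherited init, initPreviousVersion and initUnknownVersion, suggest a versioned object-map serialization scheme. A minimal sketch of how such a scheme could round-trip two of this class's fields follows; the map keys ("version", "numCategories", "numFeatures"), the version-dispatch rule, and the class name VersionedModelSketch are assumptions for illustration, not the library's actual format.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical sketch of versioned object-map serialization.
// Keys and dispatch rules are assumptions, not the library's format.
class VersionedModelSketch {
    static final int CURRENT_VERSION = 2;

    int numCategories = 2;
    int numFeatures;

    int getVersion() {
        return CURRENT_VERSION;
    }

    /** Collects the important fields into a map suitable for serialization. */
    Map<String, Object> getObjectMap() {
        Map<String, Object> map = new HashMap<>();
        map.put("version", getVersion());
        map.put("numCategories", numCategories);
        map.put("numFeatures", numFeatures);
        return map;
    }

    /** Dispatches on the stored version to the matching initializer. */
    void init(Map<String, Object> objectMap) {
        int version = (Integer) objectMap.get("version");
        if (version == CURRENT_VERSION) {
            initCurrentVersion(objectMap);
        } else if (version < CURRENT_VERSION) {
            initPreviousVersion(objectMap);
        } else {
            initUnknownVersion(objectMap);
        }
    }

    void initCurrentVersion(Map<String, Object> objectMap) {
        numCategories = (Integer) objectMap.get("numCategories");
        numFeatures = (Integer) objectMap.get("numFeatures");
    }

    void initPreviousVersion(Map<String, Object> objectMap) {
        // Assumed: older maps lacked numFeatures, so it stays at its default (0)
        numCategories = (Integer) objectMap.get("numCategories");
    }

    void initUnknownVersion(Map<String, Object> objectMap) {
        throw new IllegalArgumentException("unsupported version: " + objectMap.get("version"));
    }

    public static void main(String[] args) {
        VersionedModelSketch original = new VersionedModelSketch();
        original.numFeatures = 16;
        VersionedModelSketch restored = new VersionedModelSketch();
        restored.init(original.getObjectMap());
        System.out.println(restored.numFeatures); // 16
    }
}
```

Storing the version inside the map lets a reader pick the right initializer without knowing in advance which release wrote the data.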
public AdaptiveSGDROIFinder(int numCats, org.apache.mahout.classifier.sgd.PriorFunction priorFunction)
Parameters:
numCats - number of categories
priorFunction - prior (regularization) function to use

public AdaptiveSGDROIFinder(int numCats, org.apache.mahout.classifier.sgd.PriorFunction priorFunction, int threadCount, int poolSize)
Parameters:
numCats - number of labels in the data (e.g. 2 for is / is not a flaw)
priorFunction - regularization function used to penalize complex models
threadCount - number of cores to use during training. Defaults to the number of cores on the system.
poolSize - number of learners to train. Defaults to 20.

public AdaptiveSGDROIFinder()
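The threadCount and poolSize parameters control how many SGD learners the meta-model trains and on how many cores, and close() shuts the worker pool down. The sketch below illustrates that general pattern with java.util.concurrent: a fixed pool of threads scores a set of candidate learners in parallel and the best-scoring one is kept. It is a hypothetical illustration, not Mahout's AdaptiveLogisticRegression: each "learner" is reduced to a single learning rate, the scoring function is made up, and treating a non-positive thread count as "use every core" is an assumption modeled on the documented default.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.function.DoubleUnaryOperator;

// Hypothetical sketch: evaluate a pool of candidate learners on a fixed
// number of threads and keep the best one. Not Mahout's implementation.
class LearnerPoolSketch implements AutoCloseable {
    static final int DEFAULT_POOL_SIZE = 20; // documented default pool size

    private final ExecutorService workers;

    LearnerPoolSketch(int threadCount) {
        // Assumption: a non-positive count means "use every available core"
        int threads = threadCount > 0 ? threadCount : Runtime.getRuntime().availableProcessors();
        workers = Executors.newFixedThreadPool(threads);
    }

    /** Scores every candidate in parallel and returns the highest-scoring one. */
    double bestCandidate(List<Double> candidates, DoubleUnaryOperator score)
            throws InterruptedException, ExecutionException {
        List<Callable<Double>> tasks = new ArrayList<>();
        for (double c : candidates) {
            tasks.add(() -> score.applyAsDouble(c));
        }
        List<Future<Double>> results = workers.invokeAll(tasks); // blocks until all finish
        double best = Double.NaN;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < results.size(); i++) {
            double s = results.get(i).get();
            if (s > bestScore) {
                bestScore = s;
                best = candidates.get(i);
            }
        }
        return best;
    }

    /** Shuts down the worker threads, mirroring the documented close(). */
    @Override
    public void close() {
        workers.shutdown();
    }

    public static void main(String[] args) throws Exception {
        try (LearnerPoolSketch pool = new LearnerPoolSketch(0)) {
            // Made-up score that peaks at a learning rate of 0.1
            double best = pool.bestCandidate(List.of(0.01, 0.1, 0.5), lr -> -(lr - 0.1) * (lr - 0.1));
            System.out.println(best);
        }
    }
}
```

Using try-with-resources guarantees the worker threads are released even if scoring fails, which is the same reason the real class exposes close().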
public void train(double[][] X, int[] y) throws java.lang.Exception
Parameters:
X - N examples with M features per example
y - N labels for the N examples in X
Throws:
java.lang.Exception - if an error occurs

public org.apache.mahout.math.Vector classify(org.apache.mahout.math.DenseVector d)
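The train(double[][] X, int[] y) method takes N examples of M features each plus N labels. Checking those shapes before training can catch mistakes early; the helper below is a hypothetical sketch that verifies X is rectangular, that X and y have the same length, and that every label lies in 0..numCategories-1 (that label range is an assumption, not something this page documents). It returns M, the value one might pass to setNumFeatures(int).

```java
// Hypothetical pre-training check for train(double[][] X, int[] y)-style input.
final class TrainingDataCheck {
    private TrainingDataCheck() {}

    /**
     * Validates N examples of M features and N labels, returning M.
     * Assumes labels are category indices in [0, numCategories).
     */
    static int validate(double[][] X, int[] y, int numCategories) {
        if (X.length == 0) {
            throw new IllegalArgumentException("no training examples");
        }
        if (X.length != y.length) {
            throw new IllegalArgumentException(
                    "got " + X.length + " examples but " + y.length + " labels");
        }
        int numFeatures = X[0].length;
        for (int i = 0; i < X.length; i++) {
            if (X[i].length != numFeatures) {
                throw new IllegalArgumentException("example " + i + " has " + X[i].length
                        + " features, expected " + numFeatures);
            }
            if (y[i] < 0 || y[i] >= numCategories) {
                throw new IllegalArgumentException("label " + y[i] + " at index " + i + " is out of range");
            }
        }
        return numFeatures;
    }

    public static void main(String[] args) {
        double[][] X = {{0.1, 0.2, 0.3}, {0.4, 0.5, 0.6}};
        int[] y = {0, 1};
        System.out.println(validate(X, y, 2)); // 3
    }
}
```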
public boolean isROI(double[] data)
Parameters:
data - raw data to examine

public boolean isROI(Dataset dataset)
Parameters:
dataset - data to examine

public long getSerializationVersion()
ObjectMap
public int getVersion()
ObjectMap
public java.util.Map<java.lang.String,java.lang.Object> getObjectMap()
ObjectMap
public void initCurrentVersion(java.util.Map<java.lang.String,java.lang.Object> objectMap)
ObjectMap
Parameters:
objectMap - object graph for initialization

public int getNumCategories()
public void setNumCategories(int numCategories)
Parameters:
numCategories - number of categories

public int getNumFeatures()
public void setNumFeatures(int numFeatures)
Parameters:
numFeatures - number of features

public org.apache.mahout.classifier.sgd.PriorFunction getPriorFunction()
public void setPriorFunction(org.apache.mahout.classifier.sgd.PriorFunction priorFunction)
Parameters:
priorFunction - new regularization function

public org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression getModel()
public void setModel(org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression model)
Parameters:
model - new meta-model

public double positiveClass()
public double negativeClass()
public void close()
public int getThreadCount()
public void setThreadCount(int threadCount)
Parameters:
threadCount - new number of threads (cores)

public int getPoolSize()
public void setPoolSize(int poolSize)
Parameters:
poolSize - new number of SGD learners

public void legacyWrite(com.esotericsoftware.kryo.Kryo kryo, com.esotericsoftware.kryo.io.Output output)
public void legacyRead(com.esotericsoftware.kryo.Kryo kryo, com.esotericsoftware.kryo.io.Input input)