
Commit

[HIVEMALL-101] Fixed CI errors
myui committed Jun 15, 2017
1 parent 50b4c9a commit 65d92ff
Showing 8 changed files with 99 additions and 29 deletions.
@@ -41,8 +41,8 @@ public final class GeneralClassifierUDTF extends GeneralLearnerBaseUDTF {

@Override
protected String getLossOptionDescription() {
return "Loss function [default: HingeLoss, LogLoss, SquaredHingeLoss, ModifiedHuberLoss, "
+ "SquaredLoss, QuantileLoss, EpsilonInsensitiveLoss, HuberLoss]";
return "Loss function [HingeLoss (default), LogLoss, SquaredHingeLoss, ModifiedHuberLoss, \n"
+ ", or a regression loss: SquaredLoss, QuantileLoss, EpsilonInsensitiveLoss, HuberLoss]";
}

@Override
@@ -51,11 +51,9 @@ protected LossType getDefaultLossType() {
}

@Override
protected void checkLossFunction(LossFunction lossFunction) throws UDFArgumentException {
if(!lossFunction.forBinaryClassification()) {
throw new UDFArgumentException("The loss function `" + lossFunction.getType()
+ "` is not designed for binary classification");
}
protected void checkLossFunction(@Nonnull LossFunction lossFunction)
throws UDFArgumentException {
// accepts both binary classification losses and regression losses
}

@Override
@@ -42,6 +42,13 @@ public static Optimizer create(int ndims, @Nonnull Map<String, String> options)
throw new IllegalArgumentException("`optimizer` not defined");
}

if ("rda".equalsIgnoreCase(options.get("regularization"))
&& "adagrad".equalsIgnoreCase(optimizerName) == false) {
throw new IllegalArgumentException(
"`-regularization rda` is only supported for AdaGrad but `-optimizer "
+ optimizerName);
}

final Optimizer optimizerImpl;
if ("sgd".equalsIgnoreCase(optimizerName)) {
optimizerImpl = new Optimizer.SGD(options);
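The guard added in this hunk rejects `-regularization rda` for any optimizer other than AdaGrad. Below is a minimal sketch of the resulting behavior, assuming the factory shown here is hivemall.optimizer.DenseOptimizerFactory and that the optimizer name is read from the "optimizer" option key (both are assumptions; neither the file name nor that key is visible in this view):

    import java.util.HashMap;
    import java.util.Map;

    import hivemall.optimizer.Optimizer;             // class shown in this hunk
    import hivemall.optimizer.DenseOptimizerFactory; // assumed factory class name

    public class RdaGuardSketch {
        public static void main(String[] args) {
            final Map<String, String> options = new HashMap<>();
            options.put("optimizer", "adagrad");
            options.put("regularization", "rda");
            // Accepted: RDA regularization is implemented on top of AdaGrad.
            Optimizer adagrad = DenseOptimizerFactory.create(10, options);

            options.put("optimizer", "sgd");
            try {
                DenseOptimizerFactory.create(10, options); // now rejected
            } catch (IllegalArgumentException expected) {
                // "`-regularization rda` is only supported for AdaGrad but `-optimizer sgd` was given"
            }
        }
    }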
82 changes: 69 additions & 13 deletions core/src/main/java/hivemall/optimizer/LossFunctions.java
@@ -29,27 +29,32 @@
public final class LossFunctions {

public enum LossType {
SquaredLoss, QuantileLoss, EpsilonInsensitiveLoss, HuberLoss, HingeLoss, LogLoss,
SquaredHingeLoss, ModifiedHuberLoss
SquaredLoss, QuantileLoss, EpsilonInsensitiveLoss, SquaredEpsilonInsensitiveLoss,
HuberLoss, HingeLoss, LogLoss, SquaredHingeLoss, ModifiedHuberLoss
}

@Nonnull
public static LossFunction getLossFunction(@Nullable final String type) {
if ("SquaredLoss".equalsIgnoreCase(type)) {
if (type == null) {
throw new IllegalArgumentException("Loss function name is null");
}
final String t = type.toLowerCase();
if ("squaredloss".equals(t) || "squared".equals(t)) {
return new SquaredLoss();
} else if ("QuantileLoss".equalsIgnoreCase(type)) {
} else if ("quantileloss".equals(t) || "quantile".equals(t)) {
return new QuantileLoss();
} else if ("EpsilonInsensitiveLoss".equalsIgnoreCase(type)) {
} else if ("epsiloninsensitiveloss".equals(t) || "epsilon_insensitive".equals(t)) {
return new EpsilonInsensitiveLoss();
} else if ("HuberLoss".equalsIgnoreCase(type)) {
} else if ("squaredepsiloninsensitiveloss".equals(t)
|| "squared_epsilon_insensitive".equals(t)) {
return new SquaredEpsilonInsensitiveLoss();
} else if ("huberloss".equals(t) || "huber".equals(t)) {
return new HuberLoss();
} else if ("HingeLoss".equalsIgnoreCase(type)) {
} else if ("hingeloss".equals(t) || "hinge".equals(t)) {
return new HingeLoss();
} else if ("LogLoss".equalsIgnoreCase(type) || "LogisticLoss".equalsIgnoreCase(type)) {
} else if ("logloss".equals(t) || "log".equals(t) || "logisticloss".equals(t)
|| "logistic".equals(t)) {
return new LogLoss();
} else if ("SquaredHingeLoss".equalsIgnoreCase(type)) {
} else if ("squaredhingeloss".equals(t) || "squared_hinge".equals(t)) {
return new SquaredHingeLoss();
} else if ("ModifiedHuberLoss".equalsIgnoreCase(type)) {
} else if ("modifiedhuberloss".equals(t) || "modified_huber".equals(t)) {
return new ModifiedHuberLoss();
}
throw new IllegalArgumentException("Unsupported loss function name: " + type);
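The reworked lookup above is case-insensitive and now accepts short aliases alongside the full class names. A small usage fragment, assuming LossFunction and LossType are the nested types of LossFunctions that the signatures above suggest:

    import hivemall.optimizer.LossFunctions;
    import hivemall.optimizer.LossFunctions.LossFunction; // assumed nested-type imports
    import hivemall.optimizer.LossFunctions.LossType;

    LossFunction a = LossFunctions.getLossFunction("SquaredEpsilonInsensitiveLoss");
    LossFunction b = LossFunctions.getLossFunction("squared_epsilon_insensitive");
    assert a.getType() == LossType.SquaredEpsilonInsensitiveLoss; // both names resolve
    assert b.getType() == LossType.SquaredEpsilonInsensitiveLoss; // to the same loss
    // Other aliases accepted above: "squared", "quantile", "epsilon_insensitive",
    // "huber", "hinge", "log", "logistic", "squared_hinge", "modified_huber".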
@@ -64,6 +69,8 @@ public static LossFunction getLossFunction(@Nonnull final LossType type) {
return new QuantileLoss();
case EpsilonInsensitiveLoss:
return new EpsilonInsensitiveLoss();
case SquaredEpsilonInsensitiveLoss:
return new SquaredEpsilonInsensitiveLoss();
case HuberLoss:
return new HuberLoss();
case HingeLoss:
@@ -272,11 +279,11 @@ public double loss(final double p, final double y) {
public float dloss(final float p, final float y) {
if ((y - p) > epsilon) {// real value > predicted value + epsilon
return -1.f;
}
if ((p - y) > epsilon) {// real value < predicted value - epsilon
return 1.f;
}
return 0.f;
if ((y - p) > epsilon) {// real value > predicted value + epsilon
return -1.f;
} else if ((p - y) > epsilon) {// real value < predicted value - epsilon
return 1.f;
} else {
return 0.f;
}
}

@Override
@@ -285,6 +292,55 @@ public LossType getType() {
}
}
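
The gradient of the epsilon-insensitive loss is sign-only, which the restructured if/else chain above makes explicit. A quick numeric sketch, assuming the no-arg constructor defaults epsilon to 0.1f (the default is not shown in this hunk):

    EpsilonInsensitiveLoss eps = new EpsilonInsensitiveLoss(); // epsilon assumed 0.1f
    eps.dloss(1.0f, 1.5f);  // y - p = 0.5 > epsilon  -> -1.f
    eps.dloss(1.5f, 1.0f);  // p - y = 0.5 > epsilon  -> +1.f
    eps.dloss(1.0f, 1.05f); // |y - p| <= epsilon     ->  0.f (inside the tube)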

/**
* Squared Epsilon-Insensitive loss. <code>loss = max(0, |y - p| - epsilon)^2</code>
*/
public static final class SquaredEpsilonInsensitiveLoss extends RegressionLoss {

private float epsilon;

public SquaredEpsilonInsensitiveLoss() {
this(0.1f);
}

public SquaredEpsilonInsensitiveLoss(float epsilon) {
this.epsilon = epsilon;
}

public void setEpsilon(float epsilon) {
this.epsilon = epsilon;
}

@Override
public float loss(final float p, final float y) {
float d = Math.abs(y - p) - epsilon;
return (d > 0.f) ? (d * d) : 0.f;
}

@Override
public double loss(final double p, final double y) {
double d = Math.abs(y - p) - epsilon;
return (d > 0.d) ? (d * d) : 0.d;
}

@Override
public float dloss(final float p, final float y) {
final float z = y - p;
if (z > epsilon) {
return -2 * (z - epsilon);
} else if (-z > epsilon) {
return 2 * (-z - epsilon);
} else {
return 0.f;
}
}

@Override
public LossType getType() {
return LossType.SquaredEpsilonInsensitiveLoss;
}
}
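
A quick numeric check of the new class, using the explicit-epsilon constructor defined above:

    SquaredEpsilonInsensitiveLoss sqeps = new SquaredEpsilonInsensitiveLoss(0.1f);
    sqeps.loss(1.0f, 1.5f);   // d = |1.5 - 1.0| - 0.1 = 0.4 -> d * d = 0.16f
    sqeps.dloss(1.0f, 1.5f);  // z = y - p = 0.5 > epsilon -> -2 * (0.5 - 0.1) = -0.8f
    sqeps.loss(1.0f, 1.05f);  // |y - p| - epsilon = -0.05 <= 0 -> 0.f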

/**
* Huber regression loss.
*
@@ -40,6 +40,13 @@ public static Optimizer create(int ndims, @Nonnull Map<String, String> options)
throw new IllegalArgumentException("`optimizer` not defined");
}

if ("rda".equalsIgnoreCase(options.get("regularization"))
&& "adagrad".equalsIgnoreCase(optimizerName) == false) {
throw new IllegalArgumentException(
"`-regularization rda` is only supported for AdaGrad but `-optimizer "
+ optimizerName);
}

final Optimizer optimizerImpl;
if ("sgd".equalsIgnoreCase(optimizerName)) {
optimizerImpl = new Optimizer.SGD(options);
@@ -41,7 +41,8 @@ public final class GeneralRegressionUDTF extends GeneralLearnerBaseUDTF {

@Override
protected String getLossOptionDescription() {
return "Loss function [default: SquaredLoss, QuantileLoss, EpsilonInsensitiveLoss, HuberLoss]";
return "Loss function [default: SquaredLoss/squared, QuantileLoss/quantile, "
+ "EpsilonInsensitiveLoss/epsilon_insensitive, HuberLoss/huber]";
}

@Override
@@ -18,6 +18,8 @@
*/
package hivemall.classifier;

import hivemall.utils.math.MathUtils;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
@@ -29,7 +31,7 @@
import java.util.StringTokenizer;
import java.util.zip.GZIPInputStream;

import hivemall.utils.math.MathUtils;
import javax.annotation.Nonnull;

import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
@@ -41,8 +43,6 @@
import org.junit.Assert;
import org.junit.Test;

import javax.annotation.Nonnull;

public class GeneralClassifierUDTFTest {
private static final boolean DEBUG = false;

@@ -148,7 +148,7 @@ public void test() throws Exception {
String[] regularizations = new String[] {"NO", "L1", "L2", "ElasticNet", "RDA"};
String[] lossFunctions = new String[] {"HingeLoss", "LogLoss", "SquaredHingeLoss",
"ModifiedHuberLoss", "SquaredLoss", "QuantileLoss", "EpsilonInsensitiveLoss",
"HuberLoss"};
"SquaredEpsilonInsensitiveLoss", "HuberLoss"};

for (String opt : optimizers) {
for (String reg : regularizations) {
@@ -22,18 +22,17 @@
import java.util.Arrays;
import java.util.List;

import javax.annotation.Nonnull;

import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

import org.junit.Assert;
import org.junit.Test;

import javax.annotation.Nonnull;

public class GeneralRegressionUDTFTest {
private static final boolean DEBUG = false;

@@ -156,7 +155,7 @@ public void test() throws Exception {
String[] optimizers = new String[] {"SGD", "AdaDelta", "AdaGrad", "Adam"};
String[] regularizations = new String[] {"NO", "L1", "L2", "ElasticNet", "RDA"};
String[] lossFunctions = new String[] {"SquaredLoss", "QuantileLoss",
"EpsilonInsensitiveLoss", "HuberLoss"};
"EpsilonInsensitiveLoss", "SquaredEpsilonInsensitiveLoss", "HuberLoss"};

for (String opt : optimizers) {
for (String reg : regularizations) {
2 changes: 2 additions & 0 deletions docs/gitbook/misc/prediction.md
@@ -110,6 +110,7 @@ Below we list possible options for `train_regression` and `train_classifier`, an
- SquaredLoss
- QuantileLoss
- EpsilonInsensitiveLoss
- SquaredEpsilonInsensitiveLoss
- HuberLoss
- For `train_classifier`
- HingeLoss
@@ -119,6 +120,7 @@ Below we list possible options for `train_regression` and `train_classifier`, an
- SquaredLoss
- QuantileLoss
- EpsilonInsensitiveLoss
- SquaredEpsilonInsensitiveLoss
- HuberLoss
- Regularization function: `-reg`, `-regularization`
- L1
