Skip to content

Commit

Permalink
add benchmark interface to BeagelTreeLikelihood #1172
Browse files Browse the repository at this point in the history
  • Loading branch information
rbouckaert committed Oct 31, 2024
1 parent 5eacf16 commit 65fa457
Showing 1 changed file with 57 additions and 1 deletion.
58 changes: 57 additions & 1 deletion src/beast/base/evolution/likelihood/BeagleTreeLikelihood.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,17 @@

package beast.base.evolution.likelihood;


import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import beagle.Beagle;
import beagle.BeagleBenchmarkFlag;
import beagle.BeagleFactory;
import beagle.BeagleFlag;
import beagle.BeagleInfo;
import beagle.BenchmarkedResourceDetails;
import beagle.InstanceDetails;
import beagle.ResourceDetails;
import beast.base.core.Description;
Expand Down Expand Up @@ -63,6 +66,7 @@ public class BeagleTreeLikelihood extends TreeLikelihood {
// will wrap around.
// note: to use a different device, say device 2, start beast with
// java -Dbeagle.resource.order=2 beast.app.BeastMCMC
private static final String RESOURCE_AUTO_PROPERTY = "beagle.resource.auto";
private static final String RESOURCE_ORDER_PROPERTY = "beagle.resource.order";
private static final String PREFERRED_FLAGS_PROPERTY = "beagle.preferred.flags";
private static final String REQUIRED_FLAGS_PROPERTY = "beagle.required.flags";
Expand Down Expand Up @@ -92,7 +96,7 @@ public class BeagleTreeLikelihood extends TreeLikelihood {
private double [] currentCategoryWeights;

private int invariantCategory = -1;

@Override
public void initAndValidate() {
boolean forceJava = Boolean.valueOf(System.getProperty("java.only"));
Expand Down Expand Up @@ -273,6 +277,53 @@ private boolean initialize() {
requirementFlags |= BeagleFlag.EIGEN_COMPLEX.getMask();
}

// start auto resource selection
String resourceAuto = System.getProperty(RESOURCE_AUTO_PROPERTY);
if (resourceAuto != null && Boolean.parseBoolean(resourceAuto)) {

long benchmarkFlags = 0;

if (this.rescalingScheme == PartialsRescalingScheme.NONE) {
benchmarkFlags = BeagleBenchmarkFlag.SCALING_NONE.getMask();
} else if (this.rescalingScheme == PartialsRescalingScheme.ALWAYS) {
benchmarkFlags = BeagleBenchmarkFlag.SCALING_ALWAYS.getMask();
} else {
benchmarkFlags = BeagleBenchmarkFlag.SCALING_DYNAMIC.getMask();
}

Log.warning("\nRunning benchmarks to automatically select fastest BEAGLE resource for analysis... ");

List<BenchmarkedResourceDetails> benchmarkedResourceDetails =
BeagleFactory.getBenchmarkedResourceDetails(
tipCount,
compactPartialsCount,
m_nStateCount,
patternCount,
categoryCount,
resourceList,
preferenceFlags,
requirementFlags,
1, // eigenModelCount,
1,
0, // calculateDerivatives,
benchmarkFlags);


Log.warning(" Benchmark results, from fastest to slowest:");

for (BenchmarkedResourceDetails benchmarkedResource : benchmarkedResourceDetails) {
Log.warning(benchmarkedResource.toString());
}

long benchedFlags = benchmarkedResourceDetails.get(0).getBenchedFlags();
// if ((benchedFlags & BeagleFlag.FRAMEWORK_CPU.getMask()) != 0) {
// throw new DelegateTypeException();
// }

resourceList = new int[]{benchmarkedResourceDetails.get(0).getResourceNumber()};
}
// end auto resource selection

instanceCount++;

try {
Expand Down Expand Up @@ -1212,4 +1263,9 @@ public static PartialsRescalingScheme parseFromString(String text) {
return patternLogLikelihoods.clone();
}

@Override
public boolean isInitialisedSuccesfully() {
return beagle != null;
}

}

0 comments on commit 65fa457

Please sign in to comment.