Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ protected final boolean isCompleted() {
return state.get() == State.COMPLETED || state.get() == State.LOCAL_ABORTED;
}

boolean markAsCancelled() {
protected boolean markAsCancelled() {
return state.compareAndSet(State.STARTED, State.PENDING_CANCEL);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.license;

import org.elasticsearch.persistent.AllocatedPersistentTask;
import org.elasticsearch.persistent.PersistentTasksService;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.tasks.TaskManager;

import java.util.Map;

/**
* An AllocatedPersistentTask which automatically tracks as a licensed feature usage.
*/
public class LicensedAllocatedPersistentTask extends AllocatedPersistentTask {
private final LicensedFeature.Persistent licensedFeature;
private final String featureContext;
private final XPackLicenseState licenseState;

public LicensedAllocatedPersistentTask(long id, String type, String action, String description, TaskId parentTask,
Map<String, String> headers, LicensedFeature.Persistent feature, String featureContext,
XPackLicenseState licenseState) {
super(id, type, action, description, parentTask, headers);
this.licensedFeature = feature;
this.featureContext = featureContext;
this.licenseState = licenseState;
licensedFeature.startTracking(licenseState, featureContext);

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at the taskManager code, I think there is a small chance that a task object is destroyed but none of the cancellations methods are called.

try {
task.init(persistentTasksService, taskManager, taskInProgress.getId(), taskInProgress.getAllocationId());
logger.trace("Persistent task [{}] with id [{}] and allocation id [{}] was created", task.getAction(),
task.getPersistentTaskId(), task.getAllocationId());
try {
runningTasks.put(taskInProgress.getAllocationId(), task);
nodePersistentTasksExecutor.executeTask(taskInProgress.getParams(), taskInProgress.getState(), task, executor);
} catch (Exception e) {
// Submit task failure
task.markAsFailed(e);
}
processed = true;
} catch (Exception e) {
initializationException = e;
} finally {
if (processed == false) {
// something went wrong - unregistering task
logger.warn("Persistent task [{}] with id [{}] and allocation id [{}] failed to create", task.getAction(),
task.getPersistentTaskId(), task.getAllocationId());
taskManager.unregister(task);
if (initializationException != null) {
notifyMasterOfFailedTask(taskInProgress, initializationException);
}
}
}

Note if init is called and throws, only taskManager.unregister(task); is called. I don't think this calls any of these internal life time tracking methods.

I am thinking that licensedFeature.startTracking(licenseState, featureContext); should probably be called in init to avoid this weird condition.

What do you think?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like init is being abused currently. It's purpose, IMO is to just set some references, not to do any work? But at least rollup seems to utilize this.

What if instead I were to make init public and final? Public is so that tests can still call it, and then final so that there is no chance some extra work is done that could actually throw an exception.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rjernst if we can make that guarantee of it not throwing, it makes sense to me. All it's doing is setting internal values...so I don't know why it would ever throw.

So, I think fix init, or somehow make sure that stopTracking is called if init throws in this execution path.

}

private void stopTracking() {
licensedFeature.stopTracking(licenseState, featureContext);
}

@Override
protected final boolean markAsCancelled() {
stopTracking();
return doMarkAsCancelled();
}

protected boolean doMarkAsCancelled() {
return super.markAsCancelled();
}

@Override
public final void markAsCompleted() {
stopTracking();
doMarkAsCompleted();
}

protected void doMarkAsCompleted() {
super.markAsCompleted();
}

@Override
public final void markAsFailed(Exception e) {
stopTracking();
doMarkAsFailed(e);
}

protected void doMarkAsFailed(Exception e) {
super.markAsFailed(e);
}

@Override
public final void markAsLocallyAborted(String localAbortReason) {
stopTracking();
doMarkAsLocallyAborted(localAbortReason);
}

protected void doMarkAsLocallyAborted(String localAbortReason) {
super.markAsLocallyAborted(localAbortReason);
}

// this is made public for tests, and final to ensure it is not overridden with something that may throw
@Override
public final void init(PersistentTasksService persistentTasksService, TaskManager taskManager,
String persistentTaskId, long allocationId) {
super.init(persistentTasksService, taskManager, persistentTaskId, allocationId);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,23 @@ private Persistent(String family, String name, License.OperationMode minimumOper
*/
public boolean checkAndStartTracking(XPackLicenseState state, String contextName) {
if (state.isAllowed(this)) {
state.enableUsageTracking(this, contextName);
startTracking(state, contextName);
return true;
} else {
return false;
}
}

/**
* Starts tracking the feature.
*
* This is an alternative to {@link #checkAndStartTracking(XPackLicenseState, String)}
* where the license state has already been checked.
*/
public void startTracking(XPackLicenseState state, String contextName) {
state.enableUsageTracking(this, contextName);
}

/**
* Stop tracking the feature so that the current time will be the last that it was used.
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.license;

import org.elasticsearch.persistent.PersistentTasksService;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.tasks.TaskManager;
import org.elasticsearch.test.ESTestCase;

import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;

import static org.hamcrest.Matchers.is;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

public class LicensedAllocatedPersistentTaskTests extends ESTestCase {

void assertTrackingComplete(Consumer<LicensedAllocatedPersistentTask> method) {
XPackLicenseState licenseState = mock(XPackLicenseState.class);
LicensedFeature.Persistent feature = LicensedFeature.persistent("family", "somefeature", License.OperationMode.PLATINUM);
var task = new LicensedAllocatedPersistentTask(0, "type", "action", "description", TaskId.EMPTY_TASK_ID, Map.of(),
feature, "context", licenseState);
PersistentTasksService service = mock(PersistentTasksService.class);
TaskManager taskManager = mock(TaskManager.class);
task.init(service, taskManager, "id", 0);
verify(licenseState, times(1)).enableUsageTracking(feature, "context");
method.accept(task);
verify(licenseState, times(1)).disableUsageTracking(feature, "context");
}

public void testCompleted() {
assertTrackingComplete(LicensedAllocatedPersistentTask::markAsCompleted);
}

public void testCancelled() {
assertTrackingComplete(LicensedAllocatedPersistentTask::markAsCancelled);
}

public void testFailed() {
assertTrackingComplete(t -> t.markAsFailed(null));
}

public void testLocallyAborted() {
assertTrackingComplete(t -> t.markAsLocallyAborted("reason"));
}

public void testDoOverrides() {
XPackLicenseState licenseState = mock(XPackLicenseState.class);
LicensedFeature.Persistent feature = LicensedFeature.persistent("family", "somefeature", License.OperationMode.PLATINUM);

AtomicBoolean completedCalled = new AtomicBoolean();
AtomicBoolean cancelledCalled = new AtomicBoolean();
AtomicBoolean failedCalled = new AtomicBoolean();
AtomicBoolean abortedCalled = new AtomicBoolean();
var task = new LicensedAllocatedPersistentTask(0, "type", "action", "description", TaskId.EMPTY_TASK_ID, Map.of(),
feature, "context", licenseState) {
@Override
protected boolean doMarkAsCancelled() {
cancelledCalled.set(true);
return true;
}
@Override
protected void doMarkAsCompleted() {
completedCalled.set(true);
}
@Override
protected void doMarkAsFailed(Exception e) {
failedCalled.set(true);
}
@Override
protected void doMarkAsLocallyAborted(String reason) {
abortedCalled.set(true);
}
};

task.markAsCancelled();
assertThat(cancelledCalled.get(), is(true));
task.markAsCompleted();
assertThat(completedCalled.get(), is(true));
task.markAsFailed(null);
assertThat(failedCalled.get(), is(true));
task.markAsLocallyAborted("reason");
assertThat(abortedCalled.get(), is(true));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@
import org.elasticsearch.indices.analysis.AnalysisModule.AnalysisProvider;
import org.elasticsearch.indices.breaker.BreakerSettings;
import org.elasticsearch.ingest.Processor;
import org.elasticsearch.license.License;
import org.elasticsearch.license.LicensedFeature;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.monitor.jvm.JvmInfo;
import org.elasticsearch.monitor.os.OsProbe;
Expand Down Expand Up @@ -441,6 +443,9 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
// Recompile if you want to compare performance with C++ tokenization.
public static final boolean CATEGORIZATION_TOKENIZATION_IN_JAVA = true;

public static final LicensedFeature.Persistent ML_JOBS_FEATURE =
LicensedFeature.persistent("machine-learning", "anomaly-detection-job", License.OperationMode.PLATINUM);

@Override
public Map<String, Processor.Factory> getProcessors(Processor.Parameters parameters) {
if (this.enabled == false) {
Expand Down Expand Up @@ -942,7 +947,8 @@ public List<PersistentTasksExecutor<?>> getPersistentTasksExecutor(ClusterServic
datafeedConfigProvider.get(),
memoryTracker.get(),
client,
expressionResolver),
expressionResolver,
getLicenseState()),
new TransportStartDatafeedAction.StartDatafeedPersistentTasksExecutor(datafeedRunner.get(), expressionResolver),
new TransportStartDataFrameAnalyticsAction.TaskExecutor(settings,
client,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,17 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.persistent.AllocatedPersistentTask;
import org.elasticsearch.license.LicensedAllocatedPersistentTask;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.xpack.core.ml.action.OpenJobAction;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.job.process.autodetect.AutodetectProcessManager;

import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;

public class JobTask extends AllocatedPersistentTask implements OpenJobAction.JobTaskMatcher {
public class JobTask extends LicensedAllocatedPersistentTask implements OpenJobAction.JobTaskMatcher {

/**
* We should only progress forwards through these states: close takes precedence over vacate
Expand All @@ -33,8 +35,9 @@ enum ClosingOrVacating {
private final AtomicReference<ClosingOrVacating> closingOrVacating = new AtomicReference<>(ClosingOrVacating.NEITHER);
private volatile AutodetectProcessManager autodetectProcessManager;

JobTask(String jobId, long id, String type, String action, TaskId parentTask, Map<String, String> headers) {
super(id, type, action, "job-" + jobId, parentTask, headers);
protected JobTask(String jobId, long id, String type, String action, TaskId parentTask, Map<String, String> headers,
XPackLicenseState licenseState) {
super(id, type, action, "job-" + jobId, parentTask, headers, MachineLearning.ML_JOBS_FEATURE, "job-" + jobId, licenseState);
this.jobId = jobId;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.engine.DocumentMissingException;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.persistent.AllocatedPersistentTask;
import org.elasticsearch.persistent.PersistentTaskState;
import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
Expand Down Expand Up @@ -90,6 +91,7 @@ public static String[] indicesOfInterest(String resultsIndex) {
private final Client client;
private final JobResultsProvider jobResultsProvider;
private final AnomalyDetectionAuditor auditor;
private final XPackLicenseState licenseState;

private volatile ClusterState clusterState;

Expand All @@ -99,13 +101,15 @@ public OpenJobPersistentTasksExecutor(Settings settings,
DatafeedConfigProvider datafeedConfigProvider,
MlMemoryTracker memoryTracker,
Client client,
IndexNameExpressionResolver expressionResolver) {
IndexNameExpressionResolver expressionResolver,
XPackLicenseState licenseState) {
super(MlTasks.JOB_TASK_NAME, MachineLearning.UTILITY_THREAD_POOL_NAME, settings, clusterService, memoryTracker, expressionResolver);
this.autodetectProcessManager = Objects.requireNonNull(autodetectProcessManager);
this.datafeedConfigProvider = Objects.requireNonNull(datafeedConfigProvider);
this.client = Objects.requireNonNull(client);
this.jobResultsProvider = new JobResultsProvider(client, settings, expressionResolver);
this.auditor = new AnomalyDetectionAuditor(client, clusterService);
this.licenseState = licenseState;
clusterService.addListener(event -> clusterState = event.state());
}

Expand Down Expand Up @@ -395,7 +399,7 @@ private void openJob(JobTask jobTask) {
protected AllocatedPersistentTask createTask(long id, String type, String action, TaskId parentTaskId,
PersistentTasksCustomMetadata.PersistentTask<OpenJobAction.JobParams> persistentTask,
Map<String, String> headers) {
return new JobTask(persistentTask.getParams().getJobId(), id, type, action, parentTaskId, headers);
return new JobTask(persistentTask.getParams().getJobId(), id, type, action, parentTaskId, headers, licenseState);
}

public static Optional<ElasticsearchException> checkAssignmentState(PersistentTasksCustomMetadata.Assignment assignment,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
import org.elasticsearch.env.TestEnvironment;
import org.elasticsearch.index.analysis.AnalysisRegistry;
import org.elasticsearch.indices.TestIndexNameExpressionResolver;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.persistent.PersistentTasksService;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.tasks.TaskManager;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.annotations.AnnotationIndex;
Expand Down Expand Up @@ -74,12 +78,14 @@
import java.util.Date;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiConsumer;
Expand Down Expand Up @@ -579,12 +585,19 @@ public void testKillKillsAutodetectProcess() throws IOException {

public void testKillingAMissingJobFinishesTheTask() {
AutodetectProcessManager manager = createSpyManager();
JobTask jobTask = mock(JobTask.class);
when(jobTask.getJobId()).thenReturn("foo");
XPackLicenseState licenseState = mock(XPackLicenseState.class);
AtomicBoolean markCalled = new AtomicBoolean();
JobTask jobTask = new JobTask("foo", 0, "type", "action", TaskId.EMPTY_TASK_ID, Map.of(), licenseState) {
@Override
protected void doMarkAsCompleted() {
markCalled.set(true);
}
};
jobTask.init(mock(PersistentTasksService.class), mock(TaskManager.class), "taskid", 0);

manager.killProcess(jobTask, false, null);

verify(jobTask).markAsCompleted();
assertThat(markCalled.get(), is(true));
}

public void testProcessData_GivenStateNotOpened() {
Expand Down
Loading