[watsonstt] lazy abort (#12318)

* [watsonstt] lazy abort

Signed-off-by: Miguel Álvarez Díez <miguelwork92@gmail.com>
This commit is contained in:
GiviMAD 2022-02-20 14:25:47 +01:00 committed by GitHub
parent c03dbfd520
commit 27538fa87f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 31 additions and 22 deletions

View File

@@ -19,6 +19,7 @@ import java.util.Locale;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@@ -132,12 +133,13 @@ public class WatsonSTTService implements STTService {
.speechDetectorSensitivity(config.speechDetectorSensitivity).inactivityTimeout(config.inactivityTimeout) .speechDetectorSensitivity(config.speechDetectorSensitivity).inactivityTimeout(config.inactivityTimeout)
.build(); .build();
final AtomicReference<@Nullable WebSocket> socketRef = new AtomicReference<>(); final AtomicReference<@Nullable WebSocket> socketRef = new AtomicReference<>();
var task = executor.submit(() -> { final AtomicBoolean aborted = new AtomicBoolean(false);
executor.submit(() -> {
int retries = 2; int retries = 2;
while (retries > 0) { while (retries > 0) {
try { try {
socketRef.set(speechToText.recognizeUsingWebSocket(wsOptions, socketRef.set(speechToText.recognizeUsingWebSocket(wsOptions,
new TranscriptionListener(sttListener, config))); new TranscriptionListener(sttListener, config, aborted)));
break; break;
} catch (RuntimeException e) { } catch (RuntimeException e) {
var cause = e.getCause(); var cause = e.getCause();
@@ -157,16 +159,17 @@ public class WatsonSTTService implements STTService {
return new STTServiceHandle() { return new STTServiceHandle() {
@Override @Override
public void abort() { public void abort() {
var socket = socketRef.get(); if (!aborted.getAndSet(true)) {
if (socket != null) { var socket = socketRef.get();
socket.close(1000, null); if (socket != null) {
socket.cancel(); socket.close(1000, null);
try { socket.cancel();
Thread.sleep(100); try {
} catch (InterruptedException ignored) { Thread.sleep(100);
} catch (InterruptedException ignored) {
}
} }
} }
task.cancel(true);
} }
}; };
} }
@@ -226,13 +229,15 @@ public class WatsonSTTService implements STTService {
private final StringBuilder transcriptBuilder = new StringBuilder(); private final StringBuilder transcriptBuilder = new StringBuilder();
private final STTListener sttListener; private final STTListener sttListener;
private final WatsonSTTConfiguration config; private final WatsonSTTConfiguration config;
private final AtomicBoolean aborted;
private float confidenceSum = 0f; private float confidenceSum = 0f;
private int responseCount = 0; private int responseCount = 0;
private boolean disconnected = false; private boolean disconnected = false;
public TranscriptionListener(STTListener sttListener, WatsonSTTConfiguration config) { public TranscriptionListener(STTListener sttListener, WatsonSTTConfiguration config, AtomicBoolean aborted) {
this.sttListener = sttListener; this.sttListener = sttListener;
this.config = config; this.config = config;
this.aborted = aborted;
} }
@Override @Override
@@ -267,24 +272,28 @@ public class WatsonSTTService implements STTService {
return; return;
} }
logger.warn("TranscriptionError: {}", errorMessage); logger.warn("TranscriptionError: {}", errorMessage);
sttListener.sttEventReceived( if (!aborted.get()) {
new SpeechRecognitionErrorEvent(errorMessage != null ? errorMessage : "Unknown error")); sttListener.sttEventReceived(
new SpeechRecognitionErrorEvent(errorMessage != null ? errorMessage : "Unknown error"));
}
} }
@Override @Override
public void onDisconnected() { public void onDisconnected() {
logger.debug("onDisconnected"); logger.debug("onDisconnected");
disconnected = true; disconnected = true;
sttListener.sttEventReceived(new RecognitionStopEvent()); if (!aborted.getAndSet(true)) {
float averageConfidence = confidenceSum / (float) responseCount; sttListener.sttEventReceived(new RecognitionStopEvent());
String transcript = transcriptBuilder.toString(); float averageConfidence = confidenceSum / (float) responseCount;
if (!transcript.isBlank()) { String transcript = transcriptBuilder.toString();
sttListener.sttEventReceived(new SpeechRecognitionEvent(transcript, averageConfidence)); if (!transcript.isBlank()) {
} else { sttListener.sttEventReceived(new SpeechRecognitionEvent(transcript, averageConfidence));
if (!config.noResultsMessage.isBlank()) {
sttListener.sttEventReceived(new SpeechRecognitionErrorEvent(config.noResultsMessage));
} else { } else {
sttListener.sttEventReceived(new SpeechRecognitionErrorEvent("No results")); if (!config.noResultsMessage.isBlank()) {
sttListener.sttEventReceived(new SpeechRecognitionErrorEvent(config.noResultsMessage));
} else {
sttListener.sttEventReceived(new SpeechRecognitionErrorEvent("No results"));
}
} }
} }
} }