[watsonstt] lazy abort (#12318)

* [watsonstt] lazy abort

Signed-off-by: Miguel Álvarez Díez <miguelwork92@gmail.com>
This commit is contained in:
GiviMAD 2022-02-20 14:25:47 +01:00 committed by GitHub
parent c03dbfd520
commit 27538fa87f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 31 additions and 22 deletions

View File

@ -19,6 +19,7 @@ import java.util.Locale;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@ -132,12 +133,13 @@ public class WatsonSTTService implements STTService {
.speechDetectorSensitivity(config.speechDetectorSensitivity).inactivityTimeout(config.inactivityTimeout) .speechDetectorSensitivity(config.speechDetectorSensitivity).inactivityTimeout(config.inactivityTimeout)
.build(); .build();
final AtomicReference<@Nullable WebSocket> socketRef = new AtomicReference<>(); final AtomicReference<@Nullable WebSocket> socketRef = new AtomicReference<>();
var task = executor.submit(() -> { final AtomicBoolean aborted = new AtomicBoolean(false);
executor.submit(() -> {
int retries = 2; int retries = 2;
while (retries > 0) { while (retries > 0) {
try { try {
socketRef.set(speechToText.recognizeUsingWebSocket(wsOptions, socketRef.set(speechToText.recognizeUsingWebSocket(wsOptions,
new TranscriptionListener(sttListener, config))); new TranscriptionListener(sttListener, config, aborted)));
break; break;
} catch (RuntimeException e) { } catch (RuntimeException e) {
var cause = e.getCause(); var cause = e.getCause();
@ -157,6 +159,7 @@ public class WatsonSTTService implements STTService {
return new STTServiceHandle() { return new STTServiceHandle() {
@Override @Override
public void abort() { public void abort() {
if (!aborted.getAndSet(true)) {
var socket = socketRef.get(); var socket = socketRef.get();
if (socket != null) { if (socket != null) {
socket.close(1000, null); socket.close(1000, null);
@ -166,7 +169,7 @@ public class WatsonSTTService implements STTService {
} catch (InterruptedException ignored) { } catch (InterruptedException ignored) {
} }
} }
task.cancel(true); }
} }
}; };
} }
@ -226,13 +229,15 @@ public class WatsonSTTService implements STTService {
private final StringBuilder transcriptBuilder = new StringBuilder(); private final StringBuilder transcriptBuilder = new StringBuilder();
private final STTListener sttListener; private final STTListener sttListener;
private final WatsonSTTConfiguration config; private final WatsonSTTConfiguration config;
private final AtomicBoolean aborted;
private float confidenceSum = 0f; private float confidenceSum = 0f;
private int responseCount = 0; private int responseCount = 0;
private boolean disconnected = false; private boolean disconnected = false;
public TranscriptionListener(STTListener sttListener, WatsonSTTConfiguration config) { public TranscriptionListener(STTListener sttListener, WatsonSTTConfiguration config, AtomicBoolean aborted) {
this.sttListener = sttListener; this.sttListener = sttListener;
this.config = config; this.config = config;
this.aborted = aborted;
} }
@Override @Override
@ -267,14 +272,17 @@ public class WatsonSTTService implements STTService {
return; return;
} }
logger.warn("TranscriptionError: {}", errorMessage); logger.warn("TranscriptionError: {}", errorMessage);
if (!aborted.get()) {
sttListener.sttEventReceived( sttListener.sttEventReceived(
new SpeechRecognitionErrorEvent(errorMessage != null ? errorMessage : "Unknown error")); new SpeechRecognitionErrorEvent(errorMessage != null ? errorMessage : "Unknown error"));
} }
}
@Override @Override
public void onDisconnected() { public void onDisconnected() {
logger.debug("onDisconnected"); logger.debug("onDisconnected");
disconnected = true; disconnected = true;
if (!aborted.getAndSet(true)) {
sttListener.sttEventReceived(new RecognitionStopEvent()); sttListener.sttEventReceived(new RecognitionStopEvent());
float averageConfidence = confidenceSum / (float) responseCount; float averageConfidence = confidenceSum / (float) responseCount;
String transcript = transcriptBuilder.toString(); String transcript = transcriptBuilder.toString();
@ -288,6 +296,7 @@ public class WatsonSTTService implements STTService {
} }
} }
} }
}
@Override @Override
public void onInactivityTimeout(@Nullable RuntimeException e) { public void onInactivityTimeout(@Nullable RuntimeException e) {