[googlestt] lazy abort (#12317)

Signed-off-by: Miguel Álvarez Díez <miguelwork92@gmail.com>
This commit is contained in:
GiviMAD 2022-02-20 12:45:31 +01:00 committed by GitHub
parent a517e6e768
commit 4d77608da1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 37 additions and 38 deletions

View File

@ -24,7 +24,6 @@ import java.util.Set;
import java.util.concurrent.Future; import java.util.concurrent.Future;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
import org.eclipse.jdt.annotation.NonNullByDefault; import org.eclipse.jdt.annotation.NonNullByDefault;
import org.eclipse.jdt.annotation.Nullable; import org.eclipse.jdt.annotation.Nullable;
@ -147,17 +146,12 @@ public class GoogleSTTService implements STTService {
@Override @Override
public STTServiceHandle recognize(STTListener sttListener, AudioStream audioStream, Locale locale, public STTServiceHandle recognize(STTListener sttListener, AudioStream audioStream, Locale locale,
Set<String> set) { Set<String> set) {
AtomicBoolean keepStreaming = new AtomicBoolean(true); AtomicBoolean aborted = new AtomicBoolean(false);
Future scheduledTask = backgroundRecognize(sttListener, audioStream, keepStreaming, locale, set); backgroundRecognize(sttListener, audioStream, aborted, locale, set);
return new STTServiceHandle() { return new STTServiceHandle() {
@Override @Override
public void abort() { public void abort() {
keepStreaming.set(false); aborted.set(true);
try {
Thread.sleep(100);
} catch (InterruptedException e) {
}
scheduledTask.cancel(true);
} }
}; };
} }
@ -206,7 +200,7 @@ public class GoogleSTTService implements STTService {
} }
} }
private Future<?> backgroundRecognize(STTListener sttListener, AudioStream audioStream, AtomicBoolean keepStreaming, private Future<?> backgroundRecognize(STTListener sttListener, AudioStream audioStream, AtomicBoolean aborted,
Locale locale, Set<String> set) { Locale locale, Set<String> set) {
Credentials credentials = getCredentials(); Credentials credentials = getCredentials();
return executor.submit(() -> { return executor.submit(() -> {
@ -214,10 +208,9 @@ public class GoogleSTTService implements STTService {
ClientStream<StreamingRecognizeRequest> clientStream = null; ClientStream<StreamingRecognizeRequest> clientStream = null;
try (SpeechClient client = SpeechClient try (SpeechClient client = SpeechClient
.create(SpeechSettings.newBuilder().setCredentialsProvider(() -> credentials).build())) { .create(SpeechSettings.newBuilder().setCredentialsProvider(() -> credentials).build())) {
TranscriptionListener responseObserver = new TranscriptionListener(sttListener, config, TranscriptionListener responseObserver = new TranscriptionListener(sttListener, config, aborted);
(t) -> keepStreaming.set(false));
clientStream = client.streamingRecognizeCallable().splitCall(responseObserver); clientStream = client.streamingRecognizeCallable().splitCall(responseObserver);
streamAudio(clientStream, audioStream, responseObserver, keepStreaming, locale); streamAudio(clientStream, audioStream, responseObserver, aborted, locale);
clientStream.closeSend(); clientStream.closeSend();
logger.debug("Background recognize done"); logger.debug("Background recognize done");
} catch (IOException e) { } catch (IOException e) {
@ -232,7 +225,7 @@ public class GoogleSTTService implements STTService {
} }
private void streamAudio(ClientStream<StreamingRecognizeRequest> clientStream, AudioStream audioStream, private void streamAudio(ClientStream<StreamingRecognizeRequest> clientStream, AudioStream audioStream,
TranscriptionListener responseObserver, AtomicBoolean keepStreaming, Locale locale) throws IOException { TranscriptionListener responseObserver, AtomicBoolean aborted, Locale locale) throws IOException {
// Gather stream info and send config // Gather stream info and send config
AudioFormat streamFormat = audioStream.getFormat(); AudioFormat streamFormat = audioStream.getFormat();
RecognitionConfig.AudioEncoding streamEncoding; RecognitionConfig.AudioEncoding streamEncoding;
@ -259,10 +252,14 @@ public class GoogleSTTService implements STTService {
long maxTranscriptionMillis = (config.maxTranscriptionSeconds * 1000L); long maxTranscriptionMillis = (config.maxTranscriptionSeconds * 1000L);
long maxSilenceMillis = (config.maxSilenceSeconds * 1000L); long maxSilenceMillis = (config.maxSilenceSeconds * 1000L);
int readBytes = 6400; int readBytes = 6400;
while (keepStreaming.get()) { while (!aborted.get()) {
byte[] data = new byte[readBytes]; byte[] data = new byte[readBytes];
int dataN = audioStream.read(data); int dataN = audioStream.read(data);
if (!keepStreaming.get() || isExpiredInterval(maxTranscriptionMillis, startTime)) { if (aborted.get()) {
logger.debug("Stops listening, aborted");
break;
}
if (isExpiredInterval(maxTranscriptionMillis, startTime)) {
logger.debug("Stops listening, max transcription time reached"); logger.debug("Stops listening, max transcription time reached");
break; break;
} }
@ -328,16 +325,15 @@ public class GoogleSTTService implements STTService {
private final StringBuilder transcriptBuilder = new StringBuilder(); private final StringBuilder transcriptBuilder = new StringBuilder();
private final STTListener sttListener; private final STTListener sttListener;
GoogleSTTConfiguration config; GoogleSTTConfiguration config;
private final Consumer<@Nullable Throwable> completeListener; private final AtomicBoolean aborted;
private float confidenceSum = 0; private float confidenceSum = 0;
private int responseCount = 0; private int responseCount = 0;
private long lastInputTime = 0; private long lastInputTime = 0;
public TranscriptionListener(STTListener sttListener, GoogleSTTConfiguration config, public TranscriptionListener(STTListener sttListener, GoogleSTTConfiguration config, AtomicBoolean aborted) {
Consumer<@Nullable Throwable> completeListener) {
this.sttListener = sttListener; this.sttListener = sttListener;
this.config = config; this.config = config;
this.completeListener = completeListener; this.aborted = aborted;
} }
@Override @Override
@ -372,7 +368,7 @@ public class GoogleSTTService implements STTService {
responseCount++; responseCount++;
// when in single utterance mode we can just get one final result so complete // when in single utterance mode we can just get one final result so complete
if (config.singleUtteranceMode) { if (config.singleUtteranceMode) {
completeListener.accept(null); onComplete();
} }
} }
}); });
@ -380,16 +376,18 @@ public class GoogleSTTService implements STTService {
@Override @Override
public void onComplete() { public void onComplete() {
sttListener.sttEventReceived(new RecognitionStopEvent()); if (!aborted.getAndSet(true)) {
float averageConfidence = confidenceSum / responseCount; sttListener.sttEventReceived(new RecognitionStopEvent());
String transcript = transcriptBuilder.toString(); float averageConfidence = confidenceSum / responseCount;
if (!transcript.isBlank()) { String transcript = transcriptBuilder.toString();
sttListener.sttEventReceived(new SpeechRecognitionEvent(transcript, averageConfidence)); if (!transcript.isBlank()) {
} else { sttListener.sttEventReceived(new SpeechRecognitionEvent(transcript, averageConfidence));
if (!config.noResultsMessage.isBlank()) {
sttListener.sttEventReceived(new SpeechRecognitionErrorEvent(config.noResultsMessage));
} else { } else {
sttListener.sttEventReceived(new SpeechRecognitionErrorEvent("No results")); if (!config.noResultsMessage.isBlank()) {
sttListener.sttEventReceived(new SpeechRecognitionErrorEvent(config.noResultsMessage));
} else {
sttListener.sttEventReceived(new SpeechRecognitionErrorEvent("No results"));
}
} }
} }
} }
@ -397,14 +395,15 @@ public class GoogleSTTService implements STTService {
@Override @Override
public void onError(@Nullable Throwable t) { public void onError(@Nullable Throwable t) {
logger.warn("Recognition error: ", t); logger.warn("Recognition error: ", t);
completeListener.accept(t); if (!aborted.getAndSet(true)) {
sttListener.sttEventReceived(new RecognitionStopEvent()); sttListener.sttEventReceived(new RecognitionStopEvent());
if (!config.errorMessage.isBlank()) { if (!config.errorMessage.isBlank()) {
sttListener.sttEventReceived(new SpeechRecognitionErrorEvent(config.errorMessage)); sttListener.sttEventReceived(new SpeechRecognitionErrorEvent(config.errorMessage));
} else { } else {
String errorMessage = t.getMessage(); String errorMessage = t.getMessage();
sttListener.sttEventReceived( sttListener.sttEventReceived(
new SpeechRecognitionErrorEvent(errorMessage != null ? errorMessage : "Unknown error")); new SpeechRecognitionErrorEvent(errorMessage != null ? errorMessage : "Unknown error"));
}
} }
} }