[googlestt] lazy abort (#12317)

Signed-off-by: Miguel Álvarez Díez <miguelwork92@gmail.com>
This commit is contained in:
GiviMAD 2022-02-20 12:45:31 +01:00 committed by GitHub
parent a517e6e768
commit 4d77608da1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 37 additions and 38 deletions

View File

@ -24,7 +24,6 @@ import java.util.Set;
import java.util.concurrent.Future;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
import org.eclipse.jdt.annotation.NonNullByDefault;
import org.eclipse.jdt.annotation.Nullable;
@ -147,17 +146,12 @@ public class GoogleSTTService implements STTService {
@Override
public STTServiceHandle recognize(STTListener sttListener, AudioStream audioStream, Locale locale,
Set<String> set) {
AtomicBoolean keepStreaming = new AtomicBoolean(true);
Future scheduledTask = backgroundRecognize(sttListener, audioStream, keepStreaming, locale, set);
AtomicBoolean aborted = new AtomicBoolean(false);
backgroundRecognize(sttListener, audioStream, aborted, locale, set);
return new STTServiceHandle() {
@Override
public void abort() {
keepStreaming.set(false);
try {
Thread.sleep(100);
} catch (InterruptedException e) {
}
scheduledTask.cancel(true);
aborted.set(true);
}
};
}
@ -206,7 +200,7 @@ public class GoogleSTTService implements STTService {
}
}
private Future<?> backgroundRecognize(STTListener sttListener, AudioStream audioStream, AtomicBoolean keepStreaming,
private Future<?> backgroundRecognize(STTListener sttListener, AudioStream audioStream, AtomicBoolean aborted,
Locale locale, Set<String> set) {
Credentials credentials = getCredentials();
return executor.submit(() -> {
@ -214,10 +208,9 @@ public class GoogleSTTService implements STTService {
ClientStream<StreamingRecognizeRequest> clientStream = null;
try (SpeechClient client = SpeechClient
.create(SpeechSettings.newBuilder().setCredentialsProvider(() -> credentials).build())) {
TranscriptionListener responseObserver = new TranscriptionListener(sttListener, config,
(t) -> keepStreaming.set(false));
TranscriptionListener responseObserver = new TranscriptionListener(sttListener, config, aborted);
clientStream = client.streamingRecognizeCallable().splitCall(responseObserver);
streamAudio(clientStream, audioStream, responseObserver, keepStreaming, locale);
streamAudio(clientStream, audioStream, responseObserver, aborted, locale);
clientStream.closeSend();
logger.debug("Background recognize done");
} catch (IOException e) {
@ -232,7 +225,7 @@ public class GoogleSTTService implements STTService {
}
private void streamAudio(ClientStream<StreamingRecognizeRequest> clientStream, AudioStream audioStream,
TranscriptionListener responseObserver, AtomicBoolean keepStreaming, Locale locale) throws IOException {
TranscriptionListener responseObserver, AtomicBoolean aborted, Locale locale) throws IOException {
// Gather stream info and send config
AudioFormat streamFormat = audioStream.getFormat();
RecognitionConfig.AudioEncoding streamEncoding;
@ -259,10 +252,14 @@ public class GoogleSTTService implements STTService {
long maxTranscriptionMillis = (config.maxTranscriptionSeconds * 1000L);
long maxSilenceMillis = (config.maxSilenceSeconds * 1000L);
int readBytes = 6400;
while (keepStreaming.get()) {
while (!aborted.get()) {
byte[] data = new byte[readBytes];
int dataN = audioStream.read(data);
if (!keepStreaming.get() || isExpiredInterval(maxTranscriptionMillis, startTime)) {
if (aborted.get()) {
logger.debug("Stops listening, aborted");
break;
}
if (isExpiredInterval(maxTranscriptionMillis, startTime)) {
logger.debug("Stops listening, max transcription time reached");
break;
}
@ -328,16 +325,15 @@ public class GoogleSTTService implements STTService {
private final StringBuilder transcriptBuilder = new StringBuilder();
private final STTListener sttListener;
GoogleSTTConfiguration config;
private final Consumer<@Nullable Throwable> completeListener;
private final AtomicBoolean aborted;
private float confidenceSum = 0;
private int responseCount = 0;
private long lastInputTime = 0;
public TranscriptionListener(STTListener sttListener, GoogleSTTConfiguration config,
Consumer<@Nullable Throwable> completeListener) {
public TranscriptionListener(STTListener sttListener, GoogleSTTConfiguration config, AtomicBoolean aborted) {
this.sttListener = sttListener;
this.config = config;
this.completeListener = completeListener;
this.aborted = aborted;
}
@Override
@ -372,7 +368,7 @@ public class GoogleSTTService implements STTService {
responseCount++;
// when in single utterance mode we can just get one final result so complete
if (config.singleUtteranceMode) {
completeListener.accept(null);
onComplete();
}
}
});
@ -380,6 +376,7 @@ public class GoogleSTTService implements STTService {
@Override
public void onComplete() {
if (!aborted.getAndSet(true)) {
sttListener.sttEventReceived(new RecognitionStopEvent());
float averageConfidence = confidenceSum / responseCount;
String transcript = transcriptBuilder.toString();
@ -393,11 +390,12 @@ public class GoogleSTTService implements STTService {
}
}
}
}
@Override
public void onError(@Nullable Throwable t) {
logger.warn("Recognition error: ", t);
completeListener.accept(t);
if (!aborted.getAndSet(true)) {
sttListener.sttEventReceived(new RecognitionStopEvent());
if (!config.errorMessage.isBlank()) {
sttListener.sttEventReceived(new SpeechRecognitionErrorEvent(config.errorMessage));
@ -407,6 +405,7 @@ public class GoogleSTTService implements STTService {
new SpeechRecognitionErrorEvent(errorMessage != null ? errorMessage : "Unknown error"));
}
}
}
public long getLastInputTime() {
return lastInputTime;