[googlestt] lazy abort (#12317)
Signed-off-by: Miguel Álvarez Díez <miguelwork92@gmail.com>
This commit is contained in:
parent
a517e6e768
commit
4d77608da1
|
@ -24,7 +24,6 @@ import java.util.Set;
|
|||
import java.util.concurrent.Future;
|
||||
import java.util.concurrent.ScheduledExecutorService;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
import org.eclipse.jdt.annotation.NonNullByDefault;
|
||||
import org.eclipse.jdt.annotation.Nullable;
|
||||
|
@ -147,17 +146,12 @@ public class GoogleSTTService implements STTService {
|
|||
@Override
|
||||
public STTServiceHandle recognize(STTListener sttListener, AudioStream audioStream, Locale locale,
|
||||
Set<String> set) {
|
||||
AtomicBoolean keepStreaming = new AtomicBoolean(true);
|
||||
Future scheduledTask = backgroundRecognize(sttListener, audioStream, keepStreaming, locale, set);
|
||||
AtomicBoolean aborted = new AtomicBoolean(false);
|
||||
backgroundRecognize(sttListener, audioStream, aborted, locale, set);
|
||||
return new STTServiceHandle() {
|
||||
@Override
|
||||
public void abort() {
|
||||
keepStreaming.set(false);
|
||||
try {
|
||||
Thread.sleep(100);
|
||||
} catch (InterruptedException e) {
|
||||
}
|
||||
scheduledTask.cancel(true);
|
||||
aborted.set(true);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
@ -206,7 +200,7 @@ public class GoogleSTTService implements STTService {
|
|||
}
|
||||
}
|
||||
|
||||
private Future<?> backgroundRecognize(STTListener sttListener, AudioStream audioStream, AtomicBoolean keepStreaming,
|
||||
private Future<?> backgroundRecognize(STTListener sttListener, AudioStream audioStream, AtomicBoolean aborted,
|
||||
Locale locale, Set<String> set) {
|
||||
Credentials credentials = getCredentials();
|
||||
return executor.submit(() -> {
|
||||
|
@ -214,10 +208,9 @@ public class GoogleSTTService implements STTService {
|
|||
ClientStream<StreamingRecognizeRequest> clientStream = null;
|
||||
try (SpeechClient client = SpeechClient
|
||||
.create(SpeechSettings.newBuilder().setCredentialsProvider(() -> credentials).build())) {
|
||||
TranscriptionListener responseObserver = new TranscriptionListener(sttListener, config,
|
||||
(t) -> keepStreaming.set(false));
|
||||
TranscriptionListener responseObserver = new TranscriptionListener(sttListener, config, aborted);
|
||||
clientStream = client.streamingRecognizeCallable().splitCall(responseObserver);
|
||||
streamAudio(clientStream, audioStream, responseObserver, keepStreaming, locale);
|
||||
streamAudio(clientStream, audioStream, responseObserver, aborted, locale);
|
||||
clientStream.closeSend();
|
||||
logger.debug("Background recognize done");
|
||||
} catch (IOException e) {
|
||||
|
@ -232,7 +225,7 @@ public class GoogleSTTService implements STTService {
|
|||
}
|
||||
|
||||
private void streamAudio(ClientStream<StreamingRecognizeRequest> clientStream, AudioStream audioStream,
|
||||
TranscriptionListener responseObserver, AtomicBoolean keepStreaming, Locale locale) throws IOException {
|
||||
TranscriptionListener responseObserver, AtomicBoolean aborted, Locale locale) throws IOException {
|
||||
// Gather stream info and send config
|
||||
AudioFormat streamFormat = audioStream.getFormat();
|
||||
RecognitionConfig.AudioEncoding streamEncoding;
|
||||
|
@ -259,10 +252,14 @@ public class GoogleSTTService implements STTService {
|
|||
long maxTranscriptionMillis = (config.maxTranscriptionSeconds * 1000L);
|
||||
long maxSilenceMillis = (config.maxSilenceSeconds * 1000L);
|
||||
int readBytes = 6400;
|
||||
while (keepStreaming.get()) {
|
||||
while (!aborted.get()) {
|
||||
byte[] data = new byte[readBytes];
|
||||
int dataN = audioStream.read(data);
|
||||
if (!keepStreaming.get() || isExpiredInterval(maxTranscriptionMillis, startTime)) {
|
||||
if (aborted.get()) {
|
||||
logger.debug("Stops listening, aborted");
|
||||
break;
|
||||
}
|
||||
if (isExpiredInterval(maxTranscriptionMillis, startTime)) {
|
||||
logger.debug("Stops listening, max transcription time reached");
|
||||
break;
|
||||
}
|
||||
|
@ -328,16 +325,15 @@ public class GoogleSTTService implements STTService {
|
|||
private final StringBuilder transcriptBuilder = new StringBuilder();
|
||||
private final STTListener sttListener;
|
||||
GoogleSTTConfiguration config;
|
||||
private final Consumer<@Nullable Throwable> completeListener;
|
||||
private final AtomicBoolean aborted;
|
||||
private float confidenceSum = 0;
|
||||
private int responseCount = 0;
|
||||
private long lastInputTime = 0;
|
||||
|
||||
public TranscriptionListener(STTListener sttListener, GoogleSTTConfiguration config,
|
||||
Consumer<@Nullable Throwable> completeListener) {
|
||||
public TranscriptionListener(STTListener sttListener, GoogleSTTConfiguration config, AtomicBoolean aborted) {
|
||||
this.sttListener = sttListener;
|
||||
this.config = config;
|
||||
this.completeListener = completeListener;
|
||||
this.aborted = aborted;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -372,7 +368,7 @@ public class GoogleSTTService implements STTService {
|
|||
responseCount++;
|
||||
// when in single utterance mode we can just get one final result so complete
|
||||
if (config.singleUtteranceMode) {
|
||||
completeListener.accept(null);
|
||||
onComplete();
|
||||
}
|
||||
}
|
||||
});
|
||||
|
@ -380,6 +376,7 @@ public class GoogleSTTService implements STTService {
|
|||
|
||||
@Override
|
||||
public void onComplete() {
|
||||
if (!aborted.getAndSet(true)) {
|
||||
sttListener.sttEventReceived(new RecognitionStopEvent());
|
||||
float averageConfidence = confidenceSum / responseCount;
|
||||
String transcript = transcriptBuilder.toString();
|
||||
|
@ -393,11 +390,12 @@ public class GoogleSTTService implements STTService {
|
|||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onError(@Nullable Throwable t) {
|
||||
logger.warn("Recognition error: ", t);
|
||||
completeListener.accept(t);
|
||||
if (!aborted.getAndSet(true)) {
|
||||
sttListener.sttEventReceived(new RecognitionStopEvent());
|
||||
if (!config.errorMessage.isBlank()) {
|
||||
sttListener.sttEventReceived(new SpeechRecognitionErrorEvent(config.errorMessage));
|
||||
|
@ -407,6 +405,7 @@ public class GoogleSTTService implements STTService {
|
|||
new SpeechRecognitionErrorEvent(errorMessage != null ? errorMessage : "Unknown error"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public long getLastInputTime() {
|
||||
return lastInputTime;
|
||||
|
|
Loading…
Reference in New Issue