[googlestt] lazy abort (#12317)
Signed-off-by: Miguel Álvarez Díez <miguelwork92@gmail.com>
This commit is contained in:
parent
a517e6e768
commit
4d77608da1
|
@ -24,7 +24,6 @@ import java.util.Set;
|
||||||
import java.util.concurrent.Future;
|
import java.util.concurrent.Future;
|
||||||
import java.util.concurrent.ScheduledExecutorService;
|
import java.util.concurrent.ScheduledExecutorService;
|
||||||
import java.util.concurrent.atomic.AtomicBoolean;
|
import java.util.concurrent.atomic.AtomicBoolean;
|
||||||
import java.util.function.Consumer;
|
|
||||||
|
|
||||||
import org.eclipse.jdt.annotation.NonNullByDefault;
|
import org.eclipse.jdt.annotation.NonNullByDefault;
|
||||||
import org.eclipse.jdt.annotation.Nullable;
|
import org.eclipse.jdt.annotation.Nullable;
|
||||||
|
@ -147,17 +146,12 @@ public class GoogleSTTService implements STTService {
|
||||||
@Override
|
@Override
|
||||||
public STTServiceHandle recognize(STTListener sttListener, AudioStream audioStream, Locale locale,
|
public STTServiceHandle recognize(STTListener sttListener, AudioStream audioStream, Locale locale,
|
||||||
Set<String> set) {
|
Set<String> set) {
|
||||||
AtomicBoolean keepStreaming = new AtomicBoolean(true);
|
AtomicBoolean aborted = new AtomicBoolean(false);
|
||||||
Future scheduledTask = backgroundRecognize(sttListener, audioStream, keepStreaming, locale, set);
|
backgroundRecognize(sttListener, audioStream, aborted, locale, set);
|
||||||
return new STTServiceHandle() {
|
return new STTServiceHandle() {
|
||||||
@Override
|
@Override
|
||||||
public void abort() {
|
public void abort() {
|
||||||
keepStreaming.set(false);
|
aborted.set(true);
|
||||||
try {
|
|
||||||
Thread.sleep(100);
|
|
||||||
} catch (InterruptedException e) {
|
|
||||||
}
|
|
||||||
scheduledTask.cancel(true);
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@ -206,7 +200,7 @@ public class GoogleSTTService implements STTService {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private Future<?> backgroundRecognize(STTListener sttListener, AudioStream audioStream, AtomicBoolean keepStreaming,
|
private Future<?> backgroundRecognize(STTListener sttListener, AudioStream audioStream, AtomicBoolean aborted,
|
||||||
Locale locale, Set<String> set) {
|
Locale locale, Set<String> set) {
|
||||||
Credentials credentials = getCredentials();
|
Credentials credentials = getCredentials();
|
||||||
return executor.submit(() -> {
|
return executor.submit(() -> {
|
||||||
|
@ -214,10 +208,9 @@ public class GoogleSTTService implements STTService {
|
||||||
ClientStream<StreamingRecognizeRequest> clientStream = null;
|
ClientStream<StreamingRecognizeRequest> clientStream = null;
|
||||||
try (SpeechClient client = SpeechClient
|
try (SpeechClient client = SpeechClient
|
||||||
.create(SpeechSettings.newBuilder().setCredentialsProvider(() -> credentials).build())) {
|
.create(SpeechSettings.newBuilder().setCredentialsProvider(() -> credentials).build())) {
|
||||||
TranscriptionListener responseObserver = new TranscriptionListener(sttListener, config,
|
TranscriptionListener responseObserver = new TranscriptionListener(sttListener, config, aborted);
|
||||||
(t) -> keepStreaming.set(false));
|
|
||||||
clientStream = client.streamingRecognizeCallable().splitCall(responseObserver);
|
clientStream = client.streamingRecognizeCallable().splitCall(responseObserver);
|
||||||
streamAudio(clientStream, audioStream, responseObserver, keepStreaming, locale);
|
streamAudio(clientStream, audioStream, responseObserver, aborted, locale);
|
||||||
clientStream.closeSend();
|
clientStream.closeSend();
|
||||||
logger.debug("Background recognize done");
|
logger.debug("Background recognize done");
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
|
@ -232,7 +225,7 @@ public class GoogleSTTService implements STTService {
|
||||||
}
|
}
|
||||||
|
|
||||||
private void streamAudio(ClientStream<StreamingRecognizeRequest> clientStream, AudioStream audioStream,
|
private void streamAudio(ClientStream<StreamingRecognizeRequest> clientStream, AudioStream audioStream,
|
||||||
TranscriptionListener responseObserver, AtomicBoolean keepStreaming, Locale locale) throws IOException {
|
TranscriptionListener responseObserver, AtomicBoolean aborted, Locale locale) throws IOException {
|
||||||
// Gather stream info and send config
|
// Gather stream info and send config
|
||||||
AudioFormat streamFormat = audioStream.getFormat();
|
AudioFormat streamFormat = audioStream.getFormat();
|
||||||
RecognitionConfig.AudioEncoding streamEncoding;
|
RecognitionConfig.AudioEncoding streamEncoding;
|
||||||
|
@ -259,10 +252,14 @@ public class GoogleSTTService implements STTService {
|
||||||
long maxTranscriptionMillis = (config.maxTranscriptionSeconds * 1000L);
|
long maxTranscriptionMillis = (config.maxTranscriptionSeconds * 1000L);
|
||||||
long maxSilenceMillis = (config.maxSilenceSeconds * 1000L);
|
long maxSilenceMillis = (config.maxSilenceSeconds * 1000L);
|
||||||
int readBytes = 6400;
|
int readBytes = 6400;
|
||||||
while (keepStreaming.get()) {
|
while (!aborted.get()) {
|
||||||
byte[] data = new byte[readBytes];
|
byte[] data = new byte[readBytes];
|
||||||
int dataN = audioStream.read(data);
|
int dataN = audioStream.read(data);
|
||||||
if (!keepStreaming.get() || isExpiredInterval(maxTranscriptionMillis, startTime)) {
|
if (aborted.get()) {
|
||||||
|
logger.debug("Stops listening, aborted");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (isExpiredInterval(maxTranscriptionMillis, startTime)) {
|
||||||
logger.debug("Stops listening, max transcription time reached");
|
logger.debug("Stops listening, max transcription time reached");
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -328,16 +325,15 @@ public class GoogleSTTService implements STTService {
|
||||||
private final StringBuilder transcriptBuilder = new StringBuilder();
|
private final StringBuilder transcriptBuilder = new StringBuilder();
|
||||||
private final STTListener sttListener;
|
private final STTListener sttListener;
|
||||||
GoogleSTTConfiguration config;
|
GoogleSTTConfiguration config;
|
||||||
private final Consumer<@Nullable Throwable> completeListener;
|
private final AtomicBoolean aborted;
|
||||||
private float confidenceSum = 0;
|
private float confidenceSum = 0;
|
||||||
private int responseCount = 0;
|
private int responseCount = 0;
|
||||||
private long lastInputTime = 0;
|
private long lastInputTime = 0;
|
||||||
|
|
||||||
public TranscriptionListener(STTListener sttListener, GoogleSTTConfiguration config,
|
public TranscriptionListener(STTListener sttListener, GoogleSTTConfiguration config, AtomicBoolean aborted) {
|
||||||
Consumer<@Nullable Throwable> completeListener) {
|
|
||||||
this.sttListener = sttListener;
|
this.sttListener = sttListener;
|
||||||
this.config = config;
|
this.config = config;
|
||||||
this.completeListener = completeListener;
|
this.aborted = aborted;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -372,7 +368,7 @@ public class GoogleSTTService implements STTService {
|
||||||
responseCount++;
|
responseCount++;
|
||||||
// when in single utterance mode we can just get one final result so complete
|
// when in single utterance mode we can just get one final result so complete
|
||||||
if (config.singleUtteranceMode) {
|
if (config.singleUtteranceMode) {
|
||||||
completeListener.accept(null);
|
onComplete();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
@ -380,16 +376,18 @@ public class GoogleSTTService implements STTService {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void onComplete() {
|
public void onComplete() {
|
||||||
sttListener.sttEventReceived(new RecognitionStopEvent());
|
if (!aborted.getAndSet(true)) {
|
||||||
float averageConfidence = confidenceSum / responseCount;
|
sttListener.sttEventReceived(new RecognitionStopEvent());
|
||||||
String transcript = transcriptBuilder.toString();
|
float averageConfidence = confidenceSum / responseCount;
|
||||||
if (!transcript.isBlank()) {
|
String transcript = transcriptBuilder.toString();
|
||||||
sttListener.sttEventReceived(new SpeechRecognitionEvent(transcript, averageConfidence));
|
if (!transcript.isBlank()) {
|
||||||
} else {
|
sttListener.sttEventReceived(new SpeechRecognitionEvent(transcript, averageConfidence));
|
||||||
if (!config.noResultsMessage.isBlank()) {
|
|
||||||
sttListener.sttEventReceived(new SpeechRecognitionErrorEvent(config.noResultsMessage));
|
|
||||||
} else {
|
} else {
|
||||||
sttListener.sttEventReceived(new SpeechRecognitionErrorEvent("No results"));
|
if (!config.noResultsMessage.isBlank()) {
|
||||||
|
sttListener.sttEventReceived(new SpeechRecognitionErrorEvent(config.noResultsMessage));
|
||||||
|
} else {
|
||||||
|
sttListener.sttEventReceived(new SpeechRecognitionErrorEvent("No results"));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -397,14 +395,15 @@ public class GoogleSTTService implements STTService {
|
||||||
@Override
|
@Override
|
||||||
public void onError(@Nullable Throwable t) {
|
public void onError(@Nullable Throwable t) {
|
||||||
logger.warn("Recognition error: ", t);
|
logger.warn("Recognition error: ", t);
|
||||||
completeListener.accept(t);
|
if (!aborted.getAndSet(true)) {
|
||||||
sttListener.sttEventReceived(new RecognitionStopEvent());
|
sttListener.sttEventReceived(new RecognitionStopEvent());
|
||||||
if (!config.errorMessage.isBlank()) {
|
if (!config.errorMessage.isBlank()) {
|
||||||
sttListener.sttEventReceived(new SpeechRecognitionErrorEvent(config.errorMessage));
|
sttListener.sttEventReceived(new SpeechRecognitionErrorEvent(config.errorMessage));
|
||||||
} else {
|
} else {
|
||||||
String errorMessage = t.getMessage();
|
String errorMessage = t.getMessage();
|
||||||
sttListener.sttEventReceived(
|
sttListener.sttEventReceived(
|
||||||
new SpeechRecognitionErrorEvent(errorMessage != null ? errorMessage : "Unknown error"));
|
new SpeechRecognitionErrorEvent(errorMessage != null ? errorMessage : "Unknown error"));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue