Recognition Context
-
Cobalt Transcribe allows users to send context information with a recognition request, which may aid speech recognition. For example, if you have a list of names that you want the Transcribe model to recognize correctly, with the correct spelling, you may provide the list in the form of a
RecognitionContext object along with the RecognitionConfig before streaming data.
Transcribe models allow different sets of “context tokens”, each of which can be paired with a list of words or phrases. For example, a Transcribe model may have a context token for airport names, and you can provide a list of airport names that you want recognized correctly for that context token. Likewise, models may also be configured with tokens for “contact list names”, “menu items”, “medical jargon”, etc.
To ensure that there is no added latency in processing the list of words or
phrases during a recognition request, we have an API method called
CompileContext() that allows the user to
compile the list into a compact, efficient format for passing to the
StreamingRecognize() method.
Compiling Recognition Context
The following snippets (first in Python, then in Go) show how to compile context data and then send it with a recognition request.
import grpc
import cobaltspeech.transcribe.v5.transcribe_pb2_grpc as stub
import cobaltspeech.transcribe.v5.transcribe_pb2 as transcribe
serverAddress = "localhost:2727"
# Using a channel without TLS enabled.
channel = grpc.insecure_channel(serverAddress)
client = stub.TranscribeServiceStub(channel)
# Get server version.
versionResp = client.Version(transcribe.VersionRequest())
print(versionResp)
# Get list of models on the server.
modelResp = client.ListModels(transcribe.ListModelsRequest())
for model in modelResp.models:
print(model)
# Select a model ID from the list above. Going with the first model
# in this example. Also printing list of allowed context tokens.
m = modelResp.models[0]
print(f"context tokens = {m.attributes.context_info.allowed_context_tokens}")
# Let's say this model has an allowed context token called "airport_names" and
# we have a list of airport names that we want to make sure the recognizer gets
# right. We compile the list of names using the CompileContext(), save the compiled
# data and send it back with subsequent recognize requests to customize and improve
# the results.
#
# More typically, general models have a "catch-all" token called "unk:default" which
# can be used to boost the probabilities of any type word, as well as add words that
# are not in the model's vocabulary.
phrases = ["NARITA", "KUALA LUMPUR INTERNATIONAL", "ISTANBUL ATATURK", "LAGUARDIA"]
token = m.attributes.context_info.allowed_context_tokens[0] # "unk:default"
compileReq = transcribe.CompileContextRequest(
model_id=m.id,
token=token,
phrases=[ transcribe.ContextPhrase(text=t) for t in phrases ],
)
# Sending compilation request.
compiledResp = client.CompileContext(compileReq)
# Saving the compiled result for later use; note this compiled data is only
# compatible with the model whose ID was provided in the CompileContext call
compiledContexts = []
compiledContexts.append(compiledResp.context)
# Set the recognition config. We don't set the audio format and let the
# server auto-detect the format from the file header.
cfg = transcribe.RecognitionConfig(
model_id=m.id,
context=transcribe.RecognitionContext(compiled=compiledContexts),
)
# Open audio file.
audio = open("test.wav", "rb")
# The first request to the server should only contain the
# recognition configuration. Subsequent requests should contain
# audio bytes. We can write a simple generator to do this.
def stream(cfg, audio, bufferSize=1024):
yield transcribe.StreamingRecognizeRequest(config=cfg)
data = audio.read(bufferSize)
while len(data) > 0:
yield transcribe.StreamingRecognizeRequest(
audio=transcribe.RecognitionAudio(data=data),
)
data = audio.read(bufferSize)
# We also define a callback function to execute for each response.
# The example below just prints the formatted transcript to stdout.
def processResponse(resp):
result = resp.result
hyp = result.alternatives[0] # 1-best hypothesis.
transcript = hyp.transcript_formatted # Formatted transcript.
start = hyp.start_time_ms / 1000.0 # Converting to seconds.
end = start + hyp.duration_ms / 1000.0 # Converting to seconds.
newLine = "\r" if result.is_partial else "\n\n" # Will not move to new line for partial results.
print(f"[{start:0.2f}:{end:0.2f}] {transcript}", end=newLine)
# Streaming requests to the server.
for resp in client.StreamingRecognize(stream(cfg, audio)):
processResponse(resp)package main
import (
"context"
"errors"
"fmt"
"io"
"log"
"os"
"sync"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
transcribe "github.com/cobaltspeech/go-genproto/cobaltspeech/transcribe/v5"
)
// main runs the context-compilation example and exits with a non-zero status
// (after logging the error) if any step fails.
func main() {
	if err := run(); err != nil {
		log.Fatal(err)
	}
}

// run performs the example end to end:
//   - connects to the Transcribe server at localhost:2727;
//   - prints the server version and the available models;
//   - compiles a list of airport names into a recognition context for the
//     first model;
//   - streams test.wav to the server with that context, printing transcripts
//     as they arrive.
//
// It returns the first error encountered.
//
// All work lives here rather than in main so that the deferred cleanup
// (cancelling the context, closing the connection and the audio file) runs on
// every return path; os.Exit and log.Fatal skip deferred calls.
func run() error {
	const serverAddress = "localhost:2727"

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// NOTE(review): recent grpc-go releases deprecate grpc.DialContext and the
	// blocking-dial options below in favour of grpc.NewClient; consider
	// migrating if the grpc-go version in use supports it.
	opts := []grpc.DialOption{
		grpc.WithTransportCredentials(insecure.NewCredentials()), // Using a channel without TLS enabled.
		grpc.WithBlock(),
		grpc.WithReturnConnectionError(),
		grpc.FailOnNonTempDialError(true),
	}
	conn, err := grpc.DialContext(ctx, serverAddress, opts...)
	if err != nil {
		return fmt.Errorf("failed to dial gRPC connection: %w", err)
	}
	defer conn.Close()

	client := transcribe.NewTranscribeServiceClient(conn)

	// Get server version.
	versionResp, err := client.Version(ctx, &transcribe.VersionRequest{})
	if err != nil {
		return fmt.Errorf("failed to get server version: %w", err)
	}
	fmt.Printf("%v\n", versionResp)

	// Get the list of models on the server.
	modelResp, err := client.ListModels(ctx, &transcribe.ListModelsRequest{})
	if err != nil {
		return fmt.Errorf("failed to get model list: %w", err)
	}
	for _, m := range modelResp.Models {
		fmt.Println(m)
	}
	fmt.Println()

	// Select a model ID from the list above. Going with the first model
	// in this example. Also printing list of allowed context tokens.
	if len(modelResp.Models) == 0 {
		return errors.New("server did not report any models")
	}
	m := modelResp.Models[0]
	// The generated getters are nil-safe, so a model without attributes or
	// context info yields an empty token list instead of a nil-pointer panic.
	tokens := m.GetAttributes().GetContextInfo().GetAllowedContextTokens()
	fmt.Printf("context tokens = %v\n", tokens)

	// Let's say this model has an allowed context token called "airport_names" and
	// we have a list of airport names that we want to make sure the recognizer gets
	// right. We compile the list of names using the CompileContext(), save the compiled
	// data and send it back with subsequent recognize requests to customize and improve
	// the results.
	//
	// More typically, general models have a "catch-all" token called "unk:default" which
	// can be used to boost the probabilities of any type word, as well as add words that
	// are not in the model's vocabulary.
	if len(tokens) == 0 {
		return fmt.Errorf("model %q does not allow any context tokens", m.Id)
	}
	phrases := []string{"NARITA", "KUALA LUMPUR INTERNATIONAL", "ISTANBUL ATATURK", "LAGUARDIA"}
	token := tokens[0] // "unk:default"
	compileReq := &transcribe.CompileContextRequest{
		ModelId: m.Id,
		Token:   token,
		Phrases: make([]*transcribe.ContextPhrase, 0, len(phrases)),
	}
	for _, t := range phrases {
		compileReq.Phrases = append(compileReq.Phrases, &transcribe.ContextPhrase{
			Text: t,
		})
	}

	// Sending compilation request. It uses ctx, like every other call here, so
	// it shares the same cancellation.
	compiledResp, err := client.CompileContext(ctx, compileReq)
	if err != nil {
		return fmt.Errorf("failed to compile context: %w", err)
	}

	// Saving the compiled result for later use; note this compiled data is only
	// compatible with the model whose ID was provided in the CompileContext call
	compiledContexts := []*transcribe.CompiledContext{compiledResp.Context}

	// Set the recognition config. We don't set the audio format and let the
	// server auto-detect the format from the file header.
	cfg := &transcribe.RecognitionConfig{
		ModelId: m.Id,
		Context: &transcribe.RecognitionContext{
			Compiled: compiledContexts,
		},
	}

	// Opening audio file.
	audio, err := os.Open("test.wav")
	if err != nil {
		return fmt.Errorf("failed to open audio file: %w", err)
	}
	defer audio.Close()

	// Starting recognition.
	if err := StreamingRecognize(ctx, client, cfg, audio, printTranscript); err != nil {
		return fmt.Errorf("failed to run streaming recognition: %w", err)
	}
	return nil
}
// StreamingRecognize wraps the bidirectional streaming API for performing
// speech recognition. It sets up recognition using the given cfg.
//
// After an initial message carrying cfg, data is read from the given audio
// reader in chunks of up to 1024 bytes and streamed to the Transcribe server.
//
// As results are received from the Transcribe server, each response is passed
// to handlerFunc, which is called from the goroutine that called
// StreamingRecognize.
//
// If an error occurs while reading the audio or sending it to the server, the
// stream is cancelled so that this function returns promptly with that error
// instead of waiting for the server to finish.
//
// Otherwise, this function returns only after all results have been passed to
// handlerFunc, returning nil or the error reported by the stream.
func StreamingRecognize(
	ctx context.Context,
	client transcribe.TranscribeServiceClient,
	cfg *transcribe.RecognitionConfig,
	audio io.Reader,
	handlerFunc func(*transcribe.StreamingRecognizeResponse),
) error {
	const streamingBufSize = 1024

	// The stream gets its own cancellable context so that a failure in the
	// sending goroutine can abort the receive loop below. Without this, a
	// send-side failure that leaves the stream open (e.g. CloseSend failing)
	// would leave Recv blocked while the server waits for more audio.
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	// Creating stream.
	stream, err := client.StreamingRecognize(ctx)
	if err != nil {
		return err
	}

	// There are two concurrent processes going on. A new goroutine reads
	// audio and streams it to the server, while this goroutine receives
	// results from the stream. Errors can occur in both, so they are collected
	// in errCh. Each side sends at most one error, so a buffer of two means
	// neither send can block.
	errCh := make(chan error, 2)

	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		sendErr := sendAudio(stream, cfg, audio, streamingBufSize)
		if sendErr == nil || errors.Is(sendErr, io.EOF) {
			// io.EOF is only a notification that the stream has closed; the
			// actual status is obtained by Recv in the loop below.
			return
		}
		// Record the error before cancelling so it is the first one in errCh,
		// ahead of the cancellation error Recv will then report.
		errCh <- sendErr
		cancel()
	}()

	// Receive results from the stream.
	for {
		in, err := stream.Recv()
		if errors.Is(err, io.EOF) {
			break
		}
		if err != nil {
			errCh <- err
			// The stream has already failed; cancelling its context releases
			// it and ensures the sender does not keep using it.
			cancel()
			break
		}
		handlerFunc(in)
	}
	wg.Wait()

	select {
	case err := <-errCh:
		// There may be more than one error in the channel, but they are very
		// likely related (e.g. a connection reset causing both the send and
		// the recv to fail), so the first one is returned and the other is
		// discarded.
		return err
	default:
		return nil
	}
}
// printTranscript is a callback function given to StreamingRecognize to print
// the 1-best transcript of each response returned through the gRPC stream.
//
// Every result is printed after a carriage return, so it overwrites whatever
// is on the current terminal line:
//   - A partial result leaves the cursor on that line, so the next result
//     replaces it.
//   - A final result is followed by a blank line, so it stays visible.
//
// Server errors are printed on their own line. Responses with no result or no
// alternatives are skipped.
func printTranscript(resp *transcribe.StreamingRecognizeResponse) {
	if resp.GetError() != nil {
		fmt.Printf("\n[ERROR] server returned an error: %v\n", resp.GetError())
		return
	}

	result := resp.GetResult()
	alternatives := result.GetAlternatives()
	if len(alternatives) == 0 {
		return
	}
	hyp := alternatives[0] // 1-best hypothesis.

	// float64 keeps millisecond precision for long recordings, where float32
	// would start rounding the timestamps.
	startTime := float64(hyp.GetStartTimeMs()) / 1000.0
	endTime := startTime + float64(hyp.GetDurationMs())/1000.0

	lineEnd := "\n\n"
	if result.GetIsPartial() {
		lineEnd = ""
	}
	fmt.Printf("\r[%0.2f:%0.2f] %s%s", startTime, endTime, hyp.GetTranscriptFormatted(), lineEnd)
}
// sendAudio sends one recognition request over stream:
//   - first a message carrying cfg;
//   - then the contents of audio, in chunks of at most bufSize bytes.
//
// When audio is exhausted (io.EOF) or fails to read, sendAudio half-closes the
// stream with CloseSend so the server knows no more audio is coming.
//
// It returns one of:
//   - nil after a clean end of audio;
//   - the read error if reading failed;
//   - the error from Send or CloseSend if the stream itself failed. Send may
//     report io.EOF, meaning the stream was closed and its status must be
//     read with Recv.
func sendAudio(
	stream transcribe.TranscribeService_StreamingRecognizeClient,
	cfg *transcribe.RecognitionConfig,
	audio io.Reader,
	bufSize uint32,
) error {
	// The first message needs to be a config message, and all subsequent
	// messages must be audio messages.
	if err := stream.Send(&transcribe.StreamingRecognizeRequest{
		Request: &transcribe.StreamingRecognizeRequest_Config{Config: cfg},
	}); err != nil {
		// If this failed, we don't need to CloseSend.
		return err
	}

	// Stream the audio.
	for {
		// A fresh buffer per chunk: the message references these bytes rather
		// than copying them, and gRPC does not allow a message to be modified
		// after it has been passed to Send.
		buf := make([]byte, bufSize)
		n, readErr := audio.Read(buf)
		if n > 0 {
			if err := stream.Send(&transcribe.StreamingRecognizeRequest{
				Request: &transcribe.StreamingRecognizeRequest_Audio{
					Audio: &transcribe.RecognitionAudio{Data: buf[:n]},
				},
			}); err != nil {
				// If we couldn't Send, the stream has encountered an error
				// and we don't need to CloseSend.
				return err
			}
		}
		if readErr == nil {
			continue
		}

		// readErr is io.EOF or a genuine read failure. Either way no more
		// audio will be sent, so half-close the stream before returning.
		if err := stream.CloseSend(); err != nil {
			return err
		}
		if errors.Is(readErr, io.EOF) {
			return nil
		}
		return readErr
	}
}