Adds streaming content
This commit is contained in:
@@ -1,3 +1,9 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use deadqueue::unlimited::Queue;
|
||||
use futures_util::stream::StreamExt;
|
||||
use reqwest_eventsource::{Event, EventSource};
|
||||
|
||||
use crate::dialogue::{Message, Role};
|
||||
use crate::error::{Error, Result};
|
||||
use crate::prelude::{
|
||||
@@ -47,6 +53,42 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn streaming_stream_generate_content(
|
||||
&self,
|
||||
request: &GenerateContentRequest,
|
||||
model: Model,
|
||||
) -> Arc<Queue<Option<ResponseStreamChunk>>> {
|
||||
let queue = Arc::new(Queue::<Option<ResponseStreamChunk>>::new());
|
||||
|
||||
// Clone the queue and other necessary data to move into the async block.
|
||||
let cloned_queue = queue.clone();
|
||||
let access_token = self.token_provider.get_token(AUTH_SCOPE).await.unwrap();
|
||||
let endpoint_url: String = format!(
|
||||
"https://{}/v1beta1/projects/{}/locations/{}/publishers/google/models/{}:streamGenerateContent?alt=sse", self.api_endpoint, self.project_id, self.location_id, model.to_string(),
|
||||
);
|
||||
let client = self.client.clone();
|
||||
let request = request.clone();
|
||||
|
||||
// Start a thread to run the request in the background.
|
||||
tokio::spawn(async move {
|
||||
let req = client
|
||||
.post(&endpoint_url)
|
||||
.bearer_auth(access_token)
|
||||
.json(&request);
|
||||
let mut event_source = EventSource::new(req).unwrap();
|
||||
while let Some(Ok(event)) = event_source.next().await {
|
||||
if let Event::Message(event) = event {
|
||||
let response: ResponseStreamChunk = serde_json::from_str(&event.data).unwrap();
|
||||
cloned_queue.push(Some(response));
|
||||
}
|
||||
}
|
||||
cloned_queue.push(None);
|
||||
});
|
||||
|
||||
// Return the queue that will receive the responses.
|
||||
queue
|
||||
}
|
||||
|
||||
pub async fn stream_generate_content(
|
||||
&self,
|
||||
request: &GenerateContentRequest,
|
||||
@@ -125,6 +167,12 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
|
||||
for chunk in response {
|
||||
match chunk {
|
||||
ResponseStreamChunk::Ok(ok_response) => {
|
||||
ok_response.candidates.iter().for_each(|c| {
|
||||
if let Some(t) = c.get_text() {
|
||||
text.push_str(&t);
|
||||
}
|
||||
});
|
||||
|
||||
for candidate in ok_response.candidates {
|
||||
if let Some(parts) = &candidate.content.parts {
|
||||
for part in parts {
|
||||
|
||||
Reference in New Issue
Block a user