Adds streaming content

This commit is contained in:
2024-02-16 21:36:08 +00:00
parent cc78aea3ff
commit 83335e65dc
4 changed files with 131 additions and 7 deletions

View File

@@ -6,11 +6,15 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]
deadqueue = "0.2.4"
futures-util = "0.3.30"
gcp_auth = "0.10.0" gcp_auth = "0.10.0"
reqwest = { version = "0.11.9", features = ["json", "gzip"] } reqwest = { version = "0.11.9", features = ["json", "gzip"] }
reqwest-eventsource = "0.5.0"
serde = { version = "*", features = ["derive"] } serde = { version = "*", features = ["derive"] }
serde_json = { version = "*"} serde_json = { version = "*"}
tracing = "0.1.40" tracing = "0.1.40"
tokio = { version = "1.36.0" }
[dev-dependencies] [dev-dependencies]
console = "0.15.8" console = "0.15.8"

View File

@@ -0,0 +1,39 @@
use std::sync::Arc;
use gemini_rs::prelude::*;
use gcp_auth::AuthenticationManager;
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let authentication_manager = Arc::new(AuthenticationManager::new().await?);
let api_endpoint = std::env::var("API_ENDPOINT")?;
let project_id = std::env::var("PROJECT_ID")?;
let location_id = std::env::var("LOCATION_ID")?;
let gemini = GeminiClient::new(
authentication_manager,
api_endpoint,
project_id,
location_id,
);
let prompt = "Tell me the story of the genesis of the universe as a bedtime story.";
let request = GenerateContentRequest::from_prompt(prompt, None);
let queue = gemini
.streaming_stream_generate_content(&request, Model::GeminiPro)
.await;
while let Some(chunk) = queue.pop().await {
if let ResponseStreamChunk::Ok(ok_response) = chunk {
let text = ok_response
.candidates
.iter()
.filter_map(|c| c.get_text())
.collect::<String>();
print!("{}", text);
}
}
Ok(())
}

View File

@@ -1,3 +1,9 @@
use std::sync::Arc;
use deadqueue::unlimited::Queue;
use futures_util::stream::StreamExt;
use reqwest_eventsource::{Event, EventSource};
use crate::dialogue::{Message, Role}; use crate::dialogue::{Message, Role};
use crate::error::{Error, Result}; use crate::error::{Error, Result};
use crate::prelude::{ use crate::prelude::{
@@ -47,6 +53,42 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
} }
} }
pub async fn streaming_stream_generate_content(
&self,
request: &GenerateContentRequest,
model: Model,
) -> Arc<Queue<Option<ResponseStreamChunk>>> {
let queue = Arc::new(Queue::<Option<ResponseStreamChunk>>::new());
// Clone the queue and other necessary data to move into the async block.
let cloned_queue = queue.clone();
let access_token = self.token_provider.get_token(AUTH_SCOPE).await.unwrap();
let endpoint_url: String = format!(
"https://{}/v1beta1/projects/{}/locations/{}/publishers/google/models/{}:streamGenerateContent?alt=sse", self.api_endpoint, self.project_id, self.location_id, model.to_string(),
);
let client = self.client.clone();
let request = request.clone();
// Start a thread to run the request in the background.
tokio::spawn(async move {
let req = client
.post(&endpoint_url)
.bearer_auth(access_token)
.json(&request);
let mut event_source = EventSource::new(req).unwrap();
while let Some(Ok(event)) = event_source.next().await {
if let Event::Message(event) = event {
let response: ResponseStreamChunk = serde_json::from_str(&event.data).unwrap();
cloned_queue.push(Some(response));
}
}
cloned_queue.push(None);
});
// Return the queue that will receive the responses.
queue
}
pub async fn stream_generate_content( pub async fn stream_generate_content(
&self, &self,
request: &GenerateContentRequest, request: &GenerateContentRequest,
@@ -125,6 +167,12 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
for chunk in response { for chunk in response {
match chunk { match chunk {
ResponseStreamChunk::Ok(ok_response) => { ResponseStreamChunk::Ok(ok_response) => {
ok_response.candidates.iter().for_each(|c| {
if let Some(t) = c.get_text() {
text.push_str(&t);
}
});
for candidate in ok_response.candidates { for candidate in ok_response.candidates {
if let Some(parts) = &candidate.content.parts { if let Some(parts) = &candidate.content.parts {
for part in parts { for part in parts {

View File

@@ -13,24 +13,51 @@ pub struct CountTokensResponse {
pub total_tokens: i32, pub total_tokens: i32,
} }
#[derive(Serialize, Deserialize)] #[derive(Clone, Serialize, Deserialize)]
pub struct GenerateContentRequest { pub struct GenerateContentRequest {
pub contents: Vec<Content>, pub contents: Vec<Content>,
pub generation_config: Option<GenerationConfig>, pub generation_config: Option<GenerationConfig>,
pub tools: Option<Vec<Tools>>, pub tools: Option<Vec<Tools>>,
} }
#[derive(Serialize, Deserialize)] impl GenerateContentRequest {
pub fn from_prompt(prompt: &str, generation_config: Option<GenerationConfig>) -> Self {
GenerateContentRequest {
contents: vec![Content {
role: "user".to_string(),
parts: Some(vec![Part::Text(prompt.to_string())]),
}],
generation_config,
tools: None,
}
}
}
#[derive(Clone, Serialize, Deserialize)]
pub struct Tools { pub struct Tools {
pub function_declarations: Option<Vec<FunctionDeclaration>>, pub function_declarations: Option<Vec<FunctionDeclaration>>,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Content { pub struct Content {
pub role: String, pub role: String,
pub parts: Option<Vec<Part>>, pub parts: Option<Vec<Part>>,
} }
impl Content {
pub fn get_text(&self) -> Option<String> {
self.parts.as_ref().map(|parts| {
parts
.iter()
.filter_map(|part| match part {
Part::Text(text) => Some(text.clone()),
_ => None,
})
.collect::<String>()
})
}
}
#[derive(Clone, Debug, Serialize, Deserialize, Default)] #[derive(Clone, Debug, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct GenerationConfig { pub struct GenerationConfig {
@@ -42,7 +69,7 @@ pub struct GenerationConfig {
pub candidate_count: Option<u32>, pub candidate_count: Option<u32>,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub enum Part { pub enum Part {
Text(String), Text(String),
@@ -86,6 +113,12 @@ pub struct Candidate {
pub finish_reason: Option<String>, pub finish_reason: Option<String>,
} }
impl Candidate {
pub fn get_text(&self) -> Option<String> {
self.content.get_text()
}
}
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct SafetyRating { pub struct SafetyRating {
pub category: String, pub category: String,
@@ -113,7 +146,7 @@ pub struct UsageMetadata {
pub total_token_count: i32, pub total_token_count: i32,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct FunctionDeclaration { pub struct FunctionDeclaration {
pub name: String, pub name: String,
@@ -121,7 +154,7 @@ pub struct FunctionDeclaration {
pub parameters: FunctionParameters, pub parameters: FunctionParameters,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct FunctionParameters { pub struct FunctionParameters {
pub r#type: String, pub r#type: String,
@@ -129,7 +162,7 @@ pub struct FunctionParameters {
pub required: Vec<String>, pub required: Vec<String>,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct FunctionParametersProperty { pub struct FunctionParametersProperty {
pub r#type: String, pub r#type: String,