Adds streaming content
This commit is contained in:
@@ -6,11 +6,15 @@ edition = "2021"
|
|||||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
deadqueue = "0.2.4"
|
||||||
|
futures-util = "0.3.30"
|
||||||
gcp_auth = "0.10.0"
|
gcp_auth = "0.10.0"
|
||||||
reqwest = { version = "0.11.9", features = ["json", "gzip"] }
|
reqwest = { version = "0.11.9", features = ["json", "gzip"] }
|
||||||
|
reqwest-eventsource = "0.5.0"
|
||||||
serde = { version = "*", features = ["derive"] }
|
serde = { version = "*", features = ["derive"] }
|
||||||
serde_json = { version = "*"}
|
serde_json = { version = "*"}
|
||||||
tracing = "0.1.40"
|
tracing = "0.1.40"
|
||||||
|
tokio = { version = "1.36.0" }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
console = "0.15.8"
|
console = "0.15.8"
|
||||||
|
|||||||
39
examples/text-from-text-streaming.rs
Normal file
39
examples/text-from-text-streaming.rs
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use gemini_rs::prelude::*;
|
||||||
|
|
||||||
|
use gcp_auth::AuthenticationManager;
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
let authentication_manager = Arc::new(AuthenticationManager::new().await?);
|
||||||
|
let api_endpoint = std::env::var("API_ENDPOINT")?;
|
||||||
|
let project_id = std::env::var("PROJECT_ID")?;
|
||||||
|
let location_id = std::env::var("LOCATION_ID")?;
|
||||||
|
|
||||||
|
let gemini = GeminiClient::new(
|
||||||
|
authentication_manager,
|
||||||
|
api_endpoint,
|
||||||
|
project_id,
|
||||||
|
location_id,
|
||||||
|
);
|
||||||
|
|
||||||
|
let prompt = "Tell me the story of the genesis of the universe as a bedtime story.";
|
||||||
|
let request = GenerateContentRequest::from_prompt(prompt, None);
|
||||||
|
let queue = gemini
|
||||||
|
.streaming_stream_generate_content(&request, Model::GeminiPro)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
while let Some(chunk) = queue.pop().await {
|
||||||
|
if let ResponseStreamChunk::Ok(ok_response) = chunk {
|
||||||
|
let text = ok_response
|
||||||
|
.candidates
|
||||||
|
.iter()
|
||||||
|
.filter_map(|c| c.get_text())
|
||||||
|
.collect::<String>();
|
||||||
|
print!("{}", text);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
@@ -1,3 +1,9 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use deadqueue::unlimited::Queue;
|
||||||
|
use futures_util::stream::StreamExt;
|
||||||
|
use reqwest_eventsource::{Event, EventSource};
|
||||||
|
|
||||||
use crate::dialogue::{Message, Role};
|
use crate::dialogue::{Message, Role};
|
||||||
use crate::error::{Error, Result};
|
use crate::error::{Error, Result};
|
||||||
use crate::prelude::{
|
use crate::prelude::{
|
||||||
@@ -47,6 +53,42 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn streaming_stream_generate_content(
|
||||||
|
&self,
|
||||||
|
request: &GenerateContentRequest,
|
||||||
|
model: Model,
|
||||||
|
) -> Arc<Queue<Option<ResponseStreamChunk>>> {
|
||||||
|
let queue = Arc::new(Queue::<Option<ResponseStreamChunk>>::new());
|
||||||
|
|
||||||
|
// Clone the queue and other necessary data to move into the async block.
|
||||||
|
let cloned_queue = queue.clone();
|
||||||
|
let access_token = self.token_provider.get_token(AUTH_SCOPE).await.unwrap();
|
||||||
|
let endpoint_url: String = format!(
|
||||||
|
"https://{}/v1beta1/projects/{}/locations/{}/publishers/google/models/{}:streamGenerateContent?alt=sse", self.api_endpoint, self.project_id, self.location_id, model.to_string(),
|
||||||
|
);
|
||||||
|
let client = self.client.clone();
|
||||||
|
let request = request.clone();
|
||||||
|
|
||||||
|
// Start a thread to run the request in the background.
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let req = client
|
||||||
|
.post(&endpoint_url)
|
||||||
|
.bearer_auth(access_token)
|
||||||
|
.json(&request);
|
||||||
|
let mut event_source = EventSource::new(req).unwrap();
|
||||||
|
while let Some(Ok(event)) = event_source.next().await {
|
||||||
|
if let Event::Message(event) = event {
|
||||||
|
let response: ResponseStreamChunk = serde_json::from_str(&event.data).unwrap();
|
||||||
|
cloned_queue.push(Some(response));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cloned_queue.push(None);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Return the queue that will receive the responses.
|
||||||
|
queue
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn stream_generate_content(
|
pub async fn stream_generate_content(
|
||||||
&self,
|
&self,
|
||||||
request: &GenerateContentRequest,
|
request: &GenerateContentRequest,
|
||||||
@@ -125,6 +167,12 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
|
|||||||
for chunk in response {
|
for chunk in response {
|
||||||
match chunk {
|
match chunk {
|
||||||
ResponseStreamChunk::Ok(ok_response) => {
|
ResponseStreamChunk::Ok(ok_response) => {
|
||||||
|
ok_response.candidates.iter().for_each(|c| {
|
||||||
|
if let Some(t) = c.get_text() {
|
||||||
|
text.push_str(&t);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
for candidate in ok_response.candidates {
|
for candidate in ok_response.candidates {
|
||||||
if let Some(parts) = &candidate.content.parts {
|
if let Some(parts) = &candidate.content.parts {
|
||||||
for part in parts {
|
for part in parts {
|
||||||
|
|||||||
47
src/types.rs
47
src/types.rs
@@ -13,24 +13,51 @@ pub struct CountTokensResponse {
|
|||||||
pub total_tokens: i32,
|
pub total_tokens: i32,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
#[derive(Clone, Serialize, Deserialize)]
|
||||||
pub struct GenerateContentRequest {
|
pub struct GenerateContentRequest {
|
||||||
pub contents: Vec<Content>,
|
pub contents: Vec<Content>,
|
||||||
pub generation_config: Option<GenerationConfig>,
|
pub generation_config: Option<GenerationConfig>,
|
||||||
pub tools: Option<Vec<Tools>>,
|
pub tools: Option<Vec<Tools>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
impl GenerateContentRequest {
|
||||||
|
pub fn from_prompt(prompt: &str, generation_config: Option<GenerationConfig>) -> Self {
|
||||||
|
GenerateContentRequest {
|
||||||
|
contents: vec![Content {
|
||||||
|
role: "user".to_string(),
|
||||||
|
parts: Some(vec![Part::Text(prompt.to_string())]),
|
||||||
|
}],
|
||||||
|
generation_config,
|
||||||
|
tools: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Serialize, Deserialize)]
|
||||||
pub struct Tools {
|
pub struct Tools {
|
||||||
pub function_declarations: Option<Vec<FunctionDeclaration>>,
|
pub function_declarations: Option<Vec<FunctionDeclaration>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
pub struct Content {
|
pub struct Content {
|
||||||
pub role: String,
|
pub role: String,
|
||||||
pub parts: Option<Vec<Part>>,
|
pub parts: Option<Vec<Part>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Content {
|
||||||
|
pub fn get_text(&self) -> Option<String> {
|
||||||
|
self.parts.as_ref().map(|parts| {
|
||||||
|
parts
|
||||||
|
.iter()
|
||||||
|
.filter_map(|part| match part {
|
||||||
|
Part::Text(text) => Some(text.clone()),
|
||||||
|
_ => None,
|
||||||
|
})
|
||||||
|
.collect::<String>()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
|
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct GenerationConfig {
|
pub struct GenerationConfig {
|
||||||
@@ -42,7 +69,7 @@ pub struct GenerationConfig {
|
|||||||
pub candidate_count: Option<u32>,
|
pub candidate_count: Option<u32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub enum Part {
|
pub enum Part {
|
||||||
Text(String),
|
Text(String),
|
||||||
@@ -86,6 +113,12 @@ pub struct Candidate {
|
|||||||
pub finish_reason: Option<String>,
|
pub finish_reason: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Candidate {
|
||||||
|
pub fn get_text(&self) -> Option<String> {
|
||||||
|
self.content.get_text()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
pub struct SafetyRating {
|
pub struct SafetyRating {
|
||||||
pub category: String,
|
pub category: String,
|
||||||
@@ -113,7 +146,7 @@ pub struct UsageMetadata {
|
|||||||
pub total_token_count: i32,
|
pub total_token_count: i32,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct FunctionDeclaration {
|
pub struct FunctionDeclaration {
|
||||||
pub name: String,
|
pub name: String,
|
||||||
@@ -121,7 +154,7 @@ pub struct FunctionDeclaration {
|
|||||||
pub parameters: FunctionParameters,
|
pub parameters: FunctionParameters,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct FunctionParameters {
|
pub struct FunctionParameters {
|
||||||
pub r#type: String,
|
pub r#type: String,
|
||||||
@@ -129,7 +162,7 @@ pub struct FunctionParameters {
|
|||||||
pub required: Vec<String>,
|
pub required: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct FunctionParametersProperty {
|
pub struct FunctionParametersProperty {
|
||||||
pub r#type: String,
|
pub r#type: String,
|
||||||
|
|||||||
Reference in New Issue
Block a user