diff --git a/Cargo.toml b/Cargo.toml index e793ec7..1aee212 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,7 @@ edition = "2021" [dependencies] deadqueue = "0.2" futures-util = "0.3" -gcp_auth = "0.11" +gcp_auth = "0.12" reqwest = { version = "0.12", features = ["json", "gzip"] } reqwest-eventsource = "0.6" serde = { version = "1", features = ["derive"] } diff --git a/examples/conversation.rs b/examples/conversation.rs index a4657db..4707a71 100644 --- a/examples/conversation.rs +++ b/examples/conversation.rs @@ -1,61 +1,59 @@ -use std::{sync::Arc, time::Duration}; - -use console::style; -use dialoguer::{theme::ColorfulTheme, Input}; -use gemini_rs::prelude::*; - -use gcp_auth::AuthenticationManager; -use indicatif::{ProgressBar, ProgressStyle}; - -#[tokio::main] -async fn main() -> Result<(), Box> { - let authentication_manager = Arc::new(AuthenticationManager::new().await?); - let api_endpoint = std::env::var("API_ENDPOINT")?; - let project_id = std::env::var("PROJECT_ID")?; - let location_id = std::env::var("LOCATION_ID")?; - - tracing_subscriber::fmt().init(); - - let gemini = GeminiClient::new( - authentication_manager, - api_endpoint, - project_id, - location_id, - ); - - tracing::info!("Starting conversation..."); - - let mut conversation = Dialogue::new("gemini-pro"); - loop { - let message: String = Input::with_theme(&ColorfulTheme::default()) - .with_prompt("user") - .interact_text()?; - - // Exit the conversation if the user types "exit" - if message == "exit" { - break; - } - - // Show a spinner while the model is thinking. - let progress = ProgressBar::new_spinner(); - progress.enable_steady_tick(Duration::from_millis(120)); - progress.set_style(ProgressStyle::with_template("{spinner:.green} {msg}")?); - progress.set_message("Thinking..."); - - // Prompt the model with the conversation so far. - let response = conversation.do_turn(&gemini, &message).await?; - - // Stop the spinner and clear the terminal. - progress.finish_and_clear(); - - // Print the model's response. - println!( - "✨ {} {} {}", - style(response.role.to_string()).bold(), - style("·").dim(), - style(&response.text).cyan() - ); - } - - Ok(()) -} +use console::style; +use dialoguer::{theme::ColorfulTheme, Input}; +use gemini_rs::prelude::*; +use std::time::Duration; + +use indicatif::{ProgressBar, ProgressStyle}; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let authentication_manager = gcp_auth::provider().await?; + let api_endpoint = std::env::var("API_ENDPOINT")?; + let project_id = std::env::var("PROJECT_ID")?; + let location_id = std::env::var("LOCATION_ID")?; + + tracing_subscriber::fmt().init(); + + let gemini = GeminiClient::new( + authentication_manager, + api_endpoint, + project_id, + location_id, + ); + + tracing::info!("Starting conversation..."); + + let mut conversation = Dialogue::new("gemini-pro"); + loop { + let message: String = Input::with_theme(&ColorfulTheme::default()) + .with_prompt("user") + .interact_text()?; + + // Exit the conversation if the user types "exit" + if message == "exit" { + break; + } + + // Show a spinner while the model is thinking. + let progress = ProgressBar::new_spinner(); + progress.enable_steady_tick(Duration::from_millis(120)); + progress.set_style(ProgressStyle::with_template("{spinner:.green} {msg}")?); + progress.set_message("Thinking..."); + + // Prompt the model with the conversation so far. + let response = conversation.do_turn(&gemini, &message).await?; + + // Stop the spinner and clear the terminal. + progress.finish_and_clear(); + + // Print the model's response. + println!( + "✨ {} {} {}", + style(response.role.to_string()).bold(), + style("·").dim(), + style(&response.text).cyan() + ); + } + + Ok(()) +} diff --git a/examples/count-tokens.rs b/examples/count-tokens.rs index 9169f54..d9d8d17 100644 --- a/examples/count-tokens.rs +++ b/examples/count-tokens.rs @@ -1,12 +1,8 @@ -use std::sync::Arc; - use gemini_rs::prelude::*; -use gcp_auth::AuthenticationManager; - #[tokio::main] async fn main() -> Result<(), Box> { - let authentication_manager = Arc::new(AuthenticationManager::new().await?); + let authentication_manager = gcp_auth::provider().await?; let api_endpoint = std::env::var("API_ENDPOINT")?; let project_id = std::env::var("PROJECT_ID")?; let location_id = std::env::var("LOCATION_ID")?; diff --git a/examples/system_instruction.rs b/examples/system_instruction.rs index f98931e..ec603f3 100644 --- a/examples/system_instruction.rs +++ b/examples/system_instruction.rs @@ -1,13 +1,9 @@ -use std::sync::Arc; - use gemini_rs::prelude::*; -use gcp_auth::AuthenticationManager; - #[tokio::main] async fn main() -> Result<(), Box> { tracing_subscriber::fmt().init(); - let authentication_manager = Arc::new(AuthenticationManager::new().await?); + let authentication_manager = gcp_auth::provider().await?; let api_endpoint = std::env::var("API_ENDPOINT")?; let project_id = std::env::var("PROJECT_ID")?; let location_id = std::env::var("LOCATION_ID")?; diff --git a/examples/text-embedding.rs b/examples/text-embedding.rs index 6b2f01e..94bcb53 100644 --- a/examples/text-embedding.rs +++ b/examples/text-embedding.rs @@ -1,12 +1,8 @@ -use std::sync::Arc; - use gemini_rs::prelude::*; -use gcp_auth::AuthenticationManager; - #[tokio::main] async fn main() -> Result<(), Box> { - let authentication_manager = Arc::new(AuthenticationManager::new().await?); + let authentication_manager = gcp_auth::provider().await?; let api_endpoint = std::env::var("API_ENDPOINT")?; let project_id = std::env::var("PROJECT_ID")?; let location_id = std::env::var("LOCATION_ID")?; diff --git a/examples/text-from-text-streaming.rs b/examples/text-from-text-streaming.rs index f90af7e..adf0edb 100644 --- a/examples/text-from-text-streaming.rs +++ b/examples/text-from-text-streaming.rs @@ -1,12 +1,8 @@ -use std::sync::Arc; - use gemini_rs::prelude::*; -use gcp_auth::AuthenticationManager; - #[tokio::main] async fn main() -> Result<(), Box> { - let authentication_manager = Arc::new(AuthenticationManager::new().await?); + let authentication_manager = gcp_auth::provider().await?; let api_endpoint = std::env::var("API_ENDPOINT")?; let project_id = std::env::var("PROJECT_ID")?; let location_id = std::env::var("LOCATION_ID")?; diff --git a/examples/text-from-text.rs b/examples/text-from-text.rs index ef7abf4..f69790f 100644 --- a/examples/text-from-text.rs +++ b/examples/text-from-text.rs @@ -1,12 +1,8 @@ -use std::sync::Arc; - use gemini_rs::prelude::*; -use gcp_auth::AuthenticationManager; - #[tokio::main] async fn main() -> Result<(), Box> { - let authentication_manager = Arc::new(AuthenticationManager::new().await?); + let authentication_manager = gcp_auth::provider().await?; let api_endpoint = std::env::var("API_ENDPOINT")?; let project_id = std::env::var("PROJECT_ID")?; let location_id = std::env::var("LOCATION_ID")?; diff --git a/src/token_provider.rs b/src/token_provider.rs index d506ed3..f79284f 100644 --- a/src/token_provider.rs +++ b/src/token_provider.rs @@ -1,7 +1,5 @@ use std::sync::Arc; -use gcp_auth::AuthenticationManager; - use crate::error::Result; pub trait TokenProvider { @@ -9,18 +7,10 @@ pub trait TokenProvider { -> impl std::future::Future> + Send; } -impl TokenProvider for Arc { +impl TokenProvider for Arc { async fn get_token(&self, scope: &[&str]) -> Result { - match AuthenticationManager::get_token(self, scope).await { - Ok(token) => Ok(token.as_str().to_string()), - Err(e) => Err(e.into()), - } - } -} - -impl TokenProvider for AuthenticationManager { - async fn get_token(&self, scope: &[&str]) -> Result { - match AuthenticationManager::get_token(self, scope).await { + let token = self.token(scope).await; + match token { Ok(token) => Ok(token.as_str().to_string()), Err(e) => Err(e.into()), }