diff --git a/examples/conversation.rs b/examples/conversation.rs index 3b4af20..9ccf0b0 100644 --- a/examples/conversation.rs +++ b/examples/conversation.rs @@ -16,7 +16,7 @@ async fn main() -> Result<(), Box> { tracing_subscriber::fmt().init(); - let vertex_client = VertexClient::new( + let gemini = GeminiClient::new( authentication_manager, api_endpoint, project_id, @@ -25,7 +25,7 @@ async fn main() -> Result<(), Box> { tracing::info!("Starting conversation..."); - let mut conversation = Conversation::new(); + let mut conversation = Dialogue::new(); loop { let message: String = Input::with_theme(&ColorfulTheme::default()) .with_prompt("user") @@ -36,9 +36,6 @@ async fn main() -> Result<(), Box> { break; } - // Push the user's message to the conversation. - conversation.push_message(Message::new(Role::User, &message)); - // Show a spinner while the model is thinking. let progress = ProgressBar::new_spinner(); progress.enable_steady_tick(Duration::from_millis(120)); @@ -46,7 +43,7 @@ async fn main() -> Result<(), Box> { progress.set_message("Thinking..."); // Prompt the model with the conversation so far. - let response = vertex_client.prompt_conversation(&conversation).await?; + let response = conversation.do_turn(&gemini, &message).await?; // Stop the spinner and clear the terminal. progress.finish_and_clear(); @@ -58,7 +55,6 @@ async fn main() -> Result<(), Box> { style("ยท").dim(), style(&response.text).cyan() ); - conversation.push_message(response); } Ok(()) diff --git a/examples/text-from-text.rs b/examples/text-from-text.rs index b944482..ef7abf4 100644 --- a/examples/text-from-text.rs +++ b/examples/text-from-text.rs @@ -11,7 +11,7 @@ async fn main() -> Result<(), Box> { let project_id = std::env::var("PROJECT_ID")?; let location_id = std::env::var("LOCATION_ID")?; - let vertex_client = VertexClient::new( + let gemini = GeminiClient::new( authentication_manager, api_endpoint, project_id, @@ -19,7 +19,7 @@ async fn main() -> Result<(), Box> { ); let prompt = "What is the airspeed of an unladen swallow?"; - let result = vertex_client.prompt_text(prompt, None).await?; + let result = gemini.prompt_text(prompt, None).await?; println!("Response: {}", result); Ok(()) diff --git a/src/vertex_client.rs b/src/client.rs similarity index 82% rename from src/vertex_client.rs rename to src/client.rs index dc54405..683e83d 100644 --- a/src/vertex_client.rs +++ b/src/client.rs @@ -1,8 +1,7 @@ -use crate::conversation::{Message, Role}; +use crate::dialogue::{Message, Role}; use crate::error::{Error, Result}; use crate::prelude::{ - Content, Conversation, GenerateContentRequest, GenerateContentResponse, GenerationConfig, - ResponseStreamChunk, + Content, GenerateContentRequest, GenerateContentResponse, GenerationConfig, ResponseStreamChunk, }; use crate::{prelude::Part, token_provider::TokenProvider}; @@ -20,8 +19,8 @@ impl ToString for Model { } } -#[derive(Clone)] -pub struct VertexClient { +#[derive(Clone, Debug)] +pub struct GeminiClient { token_provider: T, client: reqwest::Client, api_endpoint: String, @@ -29,17 +28,17 @@ pub struct VertexClient { location_id: String, } -unsafe impl Send for VertexClient {} -unsafe impl Sync for VertexClient {} +unsafe impl Send for GeminiClient {} +unsafe impl Sync for GeminiClient {} -impl VertexClient { +impl GeminiClient { pub fn new( token_provider: T, api_endpoint: String, project_id: String, location_id: String, ) -> Self { - VertexClient { + GeminiClient { token_provider, client: reqwest::Client::new(), api_endpoint, @@ -76,10 +75,9 @@ impl VertexClient { } /// Prompts a conversation to the model. - pub async fn prompt_conversation(&self, conversation: &Conversation) -> Result { + pub async fn prompt_conversation(&self, messages: &[Message]) -> Result { let request = GenerateContentRequest { - contents: conversation - .messages + contents: messages .iter() .map(|m| Content { role: m.role.to_string(), @@ -95,7 +93,7 @@ impl VertexClient { .await?; // Check for errors in the response. - let text = VertexClient::::collect_text_from_response(response)?; + let text = GeminiClient::::collect_text_from_response(response)?; Ok(Message::new(Role::Model, &text)) } @@ -119,7 +117,7 @@ impl VertexClient { .stream_generate_content(&request, Model::GeminiPro) .await?; - VertexClient::::collect_text_from_response(response) + GeminiClient::::collect_text_from_response(response) } fn collect_text_from_response(response: GenerateContentResponse) -> Result { diff --git a/src/conversation.rs b/src/dialogue.rs similarity index 54% rename from src/conversation.rs rename to src/dialogue.rs index 0a8c6c2..54f5048 100644 --- a/src/conversation.rs +++ b/src/dialogue.rs @@ -2,6 +2,8 @@ use std::str::FromStr; use serde::{Deserialize, Serialize}; +use crate::{client::GeminiClient, error::Result, prelude::TokenProvider}; + #[derive(Clone, Debug, Serialize, Deserialize)] pub enum Role { User, @@ -20,7 +22,7 @@ impl ToString for Role { impl FromStr for Role { type Err = (); - fn from_str(s: &str) -> Result { + fn from_str(s: &str) -> std::result::Result { match s { "user" => Ok(Role::User), "model" => Ok(Role::Model), @@ -44,17 +46,24 @@ impl Message { } } -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct Conversation { - pub messages: Vec, +#[derive(Clone, Debug)] +pub struct Dialogue { + messages: Vec, } -impl Conversation { +impl Dialogue { pub fn new() -> Self { - Conversation { messages: vec![] } + Dialogue { messages: vec![] } } - pub fn push_message(&mut self, message: Message) { - self.messages.push(message); + pub async fn do_turn( + &mut self, + gemini: &GeminiClient, + message: &str, + ) -> Result { + self.messages.push(Message::new(Role::User, message)); + let response = gemini.prompt_conversation(&self.messages).await?; + self.messages.push(response.clone()); + Ok(response) } } diff --git a/src/lib.rs b/src/lib.rs index d95dbe2..25b517c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,12 +1,12 @@ -mod conversation; +mod client; +mod dialogue; pub mod error; mod token_provider; mod types; -mod vertex_client; pub mod prelude { - pub use crate::conversation::*; + pub use crate::client::*; + pub use crate::dialogue::*; pub use crate::token_provider::*; pub use crate::types::*; - pub use crate::vertex_client::*; }