From 3d298fabbca02e7aaa6d0ffd56dc32fa52487c27 Mon Sep 17 00:00:00 2001 From: Andre Bandarra Date: Thu, 18 Apr 2024 15:36:26 +0100 Subject: [PATCH] Allow user to choose a conversation model --- examples/conversation.rs | 2 +- src/client.rs | 4 ++-- src/dialogue.rs | 12 +++++++++--- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/examples/conversation.rs b/examples/conversation.rs index 9ccf0b0..a4657db 100644 --- a/examples/conversation.rs +++ b/examples/conversation.rs @@ -25,7 +25,7 @@ async fn main() -> Result<(), Box> { tracing::info!("Starting conversation..."); - let mut conversation = Dialogue::new(); + let mut conversation = Dialogue::new("gemini-pro"); loop { let message: String = Input::with_theme(&ColorfulTheme::default()) .with_prompt("user") diff --git a/src/client.rs b/src/client.rs index 2e4b2c3..4728e0d 100644 --- a/src/client.rs +++ b/src/client.rs @@ -116,7 +116,7 @@ impl GeminiClient { } /// Prompts a conversation to the model. - pub async fn prompt_conversation(&self, messages: &[Message]) -> Result { + pub async fn prompt_conversation(&self, messages: &[Message], model: &str) -> Result { let request = GenerateContentRequest { contents: messages .iter() @@ -129,7 +129,7 @@ impl GeminiClient { tools: None, }; - let response = self.generate_content(&request, "gemini-pro").await?; + let response = self.generate_content(&request, model).await?; // Check for errors in the response. let mut candidates = GeminiClient::::collect_text_from_response(&response)?; diff --git a/src/dialogue.rs b/src/dialogue.rs index 157b36f..56d81b9 100644 --- a/src/dialogue.rs +++ b/src/dialogue.rs @@ -48,12 +48,16 @@ impl Message { #[derive(Clone, Debug, Serialize, Deserialize)] pub struct Dialogue { + model: String, messages: Vec, } impl Dialogue { - pub fn new() -> Self { - Dialogue { messages: vec![] } + pub fn new(model: &str) -> Self { + Dialogue { + model: model.to_string(), + messages: vec![], + } } pub async fn do_turn( @@ -62,7 +66,9 @@ impl Dialogue { message: &str, ) -> Result { self.messages.push(Message::new(Role::User, message)); - let response = gemini.prompt_conversation(&self.messages).await?; + let response = gemini + .prompt_conversation(&self.messages, &self.model) + .await?; self.messages.push(response.clone()); Ok(response) }