Allow user to choose a conversation model

This commit is contained in:
2024-04-18 15:36:26 +01:00
parent 8b94651a84
commit 3d298fabbc
3 changed files with 12 additions and 6 deletions

View File

@@ -25,7 +25,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
tracing::info!("Starting conversation..."); tracing::info!("Starting conversation...");
let mut conversation = Dialogue::new(); let mut conversation = Dialogue::new("gemini-pro");
loop { loop {
let message: String = Input::with_theme(&ColorfulTheme::default()) let message: String = Input::with_theme(&ColorfulTheme::default())
.with_prompt("user") .with_prompt("user")

View File

@@ -116,7 +116,7 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
} }
/// Prompts a conversation to the model. /// Prompts a conversation to the model.
pub async fn prompt_conversation(&self, messages: &[Message]) -> Result<Message> { pub async fn prompt_conversation(&self, messages: &[Message], model: &str) -> Result<Message> {
let request = GenerateContentRequest { let request = GenerateContentRequest {
contents: messages contents: messages
.iter() .iter()
@@ -129,7 +129,7 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
tools: None, tools: None,
}; };
let response = self.generate_content(&request, "gemini-pro").await?; let response = self.generate_content(&request, model).await?;
// Check for errors in the response. // Check for errors in the response.
let mut candidates = GeminiClient::<T>::collect_text_from_response(&response)?; let mut candidates = GeminiClient::<T>::collect_text_from_response(&response)?;

View File

@@ -48,12 +48,16 @@ impl Message {
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Dialogue { pub struct Dialogue {
model: String,
messages: Vec<Message>, messages: Vec<Message>,
} }
impl Dialogue { impl Dialogue {
pub fn new() -> Self { pub fn new(model: &str) -> Self {
Dialogue { messages: vec![] } Dialogue {
model: model.to_string(),
messages: vec![],
}
} }
pub async fn do_turn<T: TokenProvider + Clone>( pub async fn do_turn<T: TokenProvider + Clone>(
@@ -62,7 +66,9 @@ impl Dialogue {
message: &str, message: &str,
) -> Result<Message> { ) -> Result<Message> {
self.messages.push(Message::new(Role::User, message)); self.messages.push(Message::new(Role::User, message));
let response = gemini.prompt_conversation(&self.messages).await?; let response = gemini
.prompt_conversation(&self.messages, &self.model)
.await?;
self.messages.push(response.clone()); self.messages.push(response.clone());
Ok(response) Ok(response)
} }