Rename client and refactors dialog

This commit is contained in:
2024-02-07 20:19:11 +00:00
parent 4f49ec8478
commit 5ab9f353bc
5 changed files with 38 additions and 35 deletions

View File

@@ -16,7 +16,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
tracing_subscriber::fmt().init(); tracing_subscriber::fmt().init();
let vertex_client = VertexClient::new( let gemini = GeminiClient::new(
authentication_manager, authentication_manager,
api_endpoint, api_endpoint,
project_id, project_id,
@@ -25,7 +25,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
tracing::info!("Starting conversation..."); tracing::info!("Starting conversation...");
let mut conversation = Conversation::new(); let mut conversation = Dialogue::new();
loop { loop {
let message: String = Input::with_theme(&ColorfulTheme::default()) let message: String = Input::with_theme(&ColorfulTheme::default())
.with_prompt("user") .with_prompt("user")
@@ -36,9 +36,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
break; break;
} }
// Push the user's message to the conversation.
conversation.push_message(Message::new(Role::User, &message));
// Show a spinner while the model is thinking. // Show a spinner while the model is thinking.
let progress = ProgressBar::new_spinner(); let progress = ProgressBar::new_spinner();
progress.enable_steady_tick(Duration::from_millis(120)); progress.enable_steady_tick(Duration::from_millis(120));
@@ -46,7 +43,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
progress.set_message("Thinking..."); progress.set_message("Thinking...");
// Prompt the model with the conversation so far. // Prompt the model with the conversation so far.
let response = vertex_client.prompt_conversation(&conversation).await?; let response = conversation.do_turn(&gemini, &message).await?;
// Stop the spinner and clear the terminal. // Stop the spinner and clear the terminal.
progress.finish_and_clear(); progress.finish_and_clear();
@@ -58,7 +55,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
style("·").dim(), style("·").dim(),
style(&response.text).cyan() style(&response.text).cyan()
); );
conversation.push_message(response);
} }
Ok(()) Ok(())

View File

@@ -11,7 +11,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let project_id = std::env::var("PROJECT_ID")?; let project_id = std::env::var("PROJECT_ID")?;
let location_id = std::env::var("LOCATION_ID")?; let location_id = std::env::var("LOCATION_ID")?;
let vertex_client = VertexClient::new( let gemini = GeminiClient::new(
authentication_manager, authentication_manager,
api_endpoint, api_endpoint,
project_id, project_id,
@@ -19,7 +19,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
); );
let prompt = "What is the airspeed of an unladen swallow?"; let prompt = "What is the airspeed of an unladen swallow?";
let result = vertex_client.prompt_text(prompt, None).await?; let result = gemini.prompt_text(prompt, None).await?;
println!("Response: {}", result); println!("Response: {}", result);
Ok(()) Ok(())

View File

@@ -1,8 +1,7 @@
use crate::conversation::{Message, Role}; use crate::dialogue::{Message, Role};
use crate::error::{Error, Result}; use crate::error::{Error, Result};
use crate::prelude::{ use crate::prelude::{
Content, Conversation, GenerateContentRequest, GenerateContentResponse, GenerationConfig, Content, GenerateContentRequest, GenerateContentResponse, GenerationConfig, ResponseStreamChunk,
ResponseStreamChunk,
}; };
use crate::{prelude::Part, token_provider::TokenProvider}; use crate::{prelude::Part, token_provider::TokenProvider};
@@ -20,8 +19,8 @@ impl ToString for Model {
} }
} }
#[derive(Clone)] #[derive(Clone, Debug)]
pub struct VertexClient<T: TokenProvider + Clone> { pub struct GeminiClient<T: TokenProvider + Clone> {
token_provider: T, token_provider: T,
client: reqwest::Client, client: reqwest::Client,
api_endpoint: String, api_endpoint: String,
@@ -29,17 +28,17 @@ pub struct VertexClient<T: TokenProvider + Clone> {
location_id: String, location_id: String,
} }
unsafe impl<T: TokenProvider + Clone> Send for VertexClient<T> {} unsafe impl<T: TokenProvider + Clone> Send for GeminiClient<T> {}
unsafe impl<T: TokenProvider + Clone> Sync for VertexClient<T> {} unsafe impl<T: TokenProvider + Clone> Sync for GeminiClient<T> {}
impl<T: TokenProvider + Clone> VertexClient<T> { impl<T: TokenProvider + Clone> GeminiClient<T> {
pub fn new( pub fn new(
token_provider: T, token_provider: T,
api_endpoint: String, api_endpoint: String,
project_id: String, project_id: String,
location_id: String, location_id: String,
) -> Self { ) -> Self {
VertexClient { GeminiClient {
token_provider, token_provider,
client: reqwest::Client::new(), client: reqwest::Client::new(),
api_endpoint, api_endpoint,
@@ -76,10 +75,9 @@ impl<T: TokenProvider + Clone> VertexClient<T> {
} }
/// Prompts a conversation to the model. /// Prompts a conversation to the model.
pub async fn prompt_conversation(&self, conversation: &Conversation) -> Result<Message> { pub async fn prompt_conversation(&self, messages: &[Message]) -> Result<Message> {
let request = GenerateContentRequest { let request = GenerateContentRequest {
contents: conversation contents: messages
.messages
.iter() .iter()
.map(|m| Content { .map(|m| Content {
role: m.role.to_string(), role: m.role.to_string(),
@@ -95,7 +93,7 @@ impl<T: TokenProvider + Clone> VertexClient<T> {
.await?; .await?;
// Check for errors in the response. // Check for errors in the response.
let text = VertexClient::<T>::collect_text_from_response(response)?; let text = GeminiClient::<T>::collect_text_from_response(response)?;
Ok(Message::new(Role::Model, &text)) Ok(Message::new(Role::Model, &text))
} }
@@ -119,7 +117,7 @@ impl<T: TokenProvider + Clone> VertexClient<T> {
.stream_generate_content(&request, Model::GeminiPro) .stream_generate_content(&request, Model::GeminiPro)
.await?; .await?;
VertexClient::<T>::collect_text_from_response(response) GeminiClient::<T>::collect_text_from_response(response)
} }
fn collect_text_from_response(response: GenerateContentResponse) -> Result<String> { fn collect_text_from_response(response: GenerateContentResponse) -> Result<String> {

View File

@@ -2,6 +2,8 @@ use std::str::FromStr;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::{client::GeminiClient, error::Result, prelude::TokenProvider};
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
pub enum Role { pub enum Role {
User, User,
@@ -20,7 +22,7 @@ impl ToString for Role {
impl FromStr for Role { impl FromStr for Role {
type Err = (); type Err = ();
fn from_str(s: &str) -> Result<Self, Self::Err> { fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
match s { match s {
"user" => Ok(Role::User), "user" => Ok(Role::User),
"model" => Ok(Role::Model), "model" => Ok(Role::Model),
@@ -44,17 +46,24 @@ impl Message {
} }
} }
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug)]
pub struct Conversation { pub struct Dialogue {
pub messages: Vec<Message>, messages: Vec<Message>,
} }
impl Conversation { impl Dialogue {
pub fn new() -> Self { pub fn new() -> Self {
Conversation { messages: vec![] } Dialogue { messages: vec![] }
} }
pub fn push_message(&mut self, message: Message) { pub async fn do_turn<T: TokenProvider + Clone>(
self.messages.push(message); &mut self,
gemini: &GeminiClient<T>,
message: &str,
) -> Result<Message> {
self.messages.push(Message::new(Role::User, message));
let response = gemini.prompt_conversation(&self.messages).await?;
self.messages.push(response.clone());
Ok(response)
} }
} }

View File

@@ -1,12 +1,12 @@
mod conversation; mod client;
mod dialogue;
pub mod error; pub mod error;
mod token_provider; mod token_provider;
mod types; mod types;
mod vertex_client;
pub mod prelude { pub mod prelude {
pub use crate::conversation::*; pub use crate::client::*;
pub use crate::dialogue::*;
pub use crate::token_provider::*; pub use crate::token_provider::*;
pub use crate::types::*; pub use crate::types::*;
pub use crate::vertex_client::*;
} }