Rename client and refactors dialog
This commit is contained in:
@@ -16,7 +16,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
|
|
||||||
tracing_subscriber::fmt().init();
|
tracing_subscriber::fmt().init();
|
||||||
|
|
||||||
let vertex_client = VertexClient::new(
|
let gemini = GeminiClient::new(
|
||||||
authentication_manager,
|
authentication_manager,
|
||||||
api_endpoint,
|
api_endpoint,
|
||||||
project_id,
|
project_id,
|
||||||
@@ -25,7 +25,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
|
|
||||||
tracing::info!("Starting conversation...");
|
tracing::info!("Starting conversation...");
|
||||||
|
|
||||||
let mut conversation = Conversation::new();
|
let mut conversation = Dialogue::new();
|
||||||
loop {
|
loop {
|
||||||
let message: String = Input::with_theme(&ColorfulTheme::default())
|
let message: String = Input::with_theme(&ColorfulTheme::default())
|
||||||
.with_prompt("user")
|
.with_prompt("user")
|
||||||
@@ -36,9 +36,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Push the user's message to the conversation.
|
|
||||||
conversation.push_message(Message::new(Role::User, &message));
|
|
||||||
|
|
||||||
// Show a spinner while the model is thinking.
|
// Show a spinner while the model is thinking.
|
||||||
let progress = ProgressBar::new_spinner();
|
let progress = ProgressBar::new_spinner();
|
||||||
progress.enable_steady_tick(Duration::from_millis(120));
|
progress.enable_steady_tick(Duration::from_millis(120));
|
||||||
@@ -46,7 +43,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
progress.set_message("Thinking...");
|
progress.set_message("Thinking...");
|
||||||
|
|
||||||
// Prompt the model with the conversation so far.
|
// Prompt the model with the conversation so far.
|
||||||
let response = vertex_client.prompt_conversation(&conversation).await?;
|
let response = conversation.do_turn(&gemini, &message).await?;
|
||||||
|
|
||||||
// Stop the spinner and clear the terminal.
|
// Stop the spinner and clear the terminal.
|
||||||
progress.finish_and_clear();
|
progress.finish_and_clear();
|
||||||
@@ -58,7 +55,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
style("·").dim(),
|
style("·").dim(),
|
||||||
style(&response.text).cyan()
|
style(&response.text).cyan()
|
||||||
);
|
);
|
||||||
conversation.push_message(response);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
let project_id = std::env::var("PROJECT_ID")?;
|
let project_id = std::env::var("PROJECT_ID")?;
|
||||||
let location_id = std::env::var("LOCATION_ID")?;
|
let location_id = std::env::var("LOCATION_ID")?;
|
||||||
|
|
||||||
let vertex_client = VertexClient::new(
|
let gemini = GeminiClient::new(
|
||||||
authentication_manager,
|
authentication_manager,
|
||||||
api_endpoint,
|
api_endpoint,
|
||||||
project_id,
|
project_id,
|
||||||
@@ -19,7 +19,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
);
|
);
|
||||||
|
|
||||||
let prompt = "What is the airspeed of an unladen swallow?";
|
let prompt = "What is the airspeed of an unladen swallow?";
|
||||||
let result = vertex_client.prompt_text(prompt, None).await?;
|
let result = gemini.prompt_text(prompt, None).await?;
|
||||||
println!("Response: {}", result);
|
println!("Response: {}", result);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
use crate::conversation::{Message, Role};
|
use crate::dialogue::{Message, Role};
|
||||||
use crate::error::{Error, Result};
|
use crate::error::{Error, Result};
|
||||||
use crate::prelude::{
|
use crate::prelude::{
|
||||||
Content, Conversation, GenerateContentRequest, GenerateContentResponse, GenerationConfig,
|
Content, GenerateContentRequest, GenerateContentResponse, GenerationConfig, ResponseStreamChunk,
|
||||||
ResponseStreamChunk,
|
|
||||||
};
|
};
|
||||||
use crate::{prelude::Part, token_provider::TokenProvider};
|
use crate::{prelude::Part, token_provider::TokenProvider};
|
||||||
|
|
||||||
@@ -20,8 +19,8 @@ impl ToString for Model {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct VertexClient<T: TokenProvider + Clone> {
|
pub struct GeminiClient<T: TokenProvider + Clone> {
|
||||||
token_provider: T,
|
token_provider: T,
|
||||||
client: reqwest::Client,
|
client: reqwest::Client,
|
||||||
api_endpoint: String,
|
api_endpoint: String,
|
||||||
@@ -29,17 +28,17 @@ pub struct VertexClient<T: TokenProvider + Clone> {
|
|||||||
location_id: String,
|
location_id: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe impl<T: TokenProvider + Clone> Send for VertexClient<T> {}
|
unsafe impl<T: TokenProvider + Clone> Send for GeminiClient<T> {}
|
||||||
unsafe impl<T: TokenProvider + Clone> Sync for VertexClient<T> {}
|
unsafe impl<T: TokenProvider + Clone> Sync for GeminiClient<T> {}
|
||||||
|
|
||||||
impl<T: TokenProvider + Clone> VertexClient<T> {
|
impl<T: TokenProvider + Clone> GeminiClient<T> {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
token_provider: T,
|
token_provider: T,
|
||||||
api_endpoint: String,
|
api_endpoint: String,
|
||||||
project_id: String,
|
project_id: String,
|
||||||
location_id: String,
|
location_id: String,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
VertexClient {
|
GeminiClient {
|
||||||
token_provider,
|
token_provider,
|
||||||
client: reqwest::Client::new(),
|
client: reqwest::Client::new(),
|
||||||
api_endpoint,
|
api_endpoint,
|
||||||
@@ -76,10 +75,9 @@ impl<T: TokenProvider + Clone> VertexClient<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Prompts a conversation to the model.
|
/// Prompts a conversation to the model.
|
||||||
pub async fn prompt_conversation(&self, conversation: &Conversation) -> Result<Message> {
|
pub async fn prompt_conversation(&self, messages: &[Message]) -> Result<Message> {
|
||||||
let request = GenerateContentRequest {
|
let request = GenerateContentRequest {
|
||||||
contents: conversation
|
contents: messages
|
||||||
.messages
|
|
||||||
.iter()
|
.iter()
|
||||||
.map(|m| Content {
|
.map(|m| Content {
|
||||||
role: m.role.to_string(),
|
role: m.role.to_string(),
|
||||||
@@ -95,7 +93,7 @@ impl<T: TokenProvider + Clone> VertexClient<T> {
|
|||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
// Check for errors in the response.
|
// Check for errors in the response.
|
||||||
let text = VertexClient::<T>::collect_text_from_response(response)?;
|
let text = GeminiClient::<T>::collect_text_from_response(response)?;
|
||||||
Ok(Message::new(Role::Model, &text))
|
Ok(Message::new(Role::Model, &text))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -119,7 +117,7 @@ impl<T: TokenProvider + Clone> VertexClient<T> {
|
|||||||
.stream_generate_content(&request, Model::GeminiPro)
|
.stream_generate_content(&request, Model::GeminiPro)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
VertexClient::<T>::collect_text_from_response(response)
|
GeminiClient::<T>::collect_text_from_response(response)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn collect_text_from_response(response: GenerateContentResponse) -> Result<String> {
|
fn collect_text_from_response(response: GenerateContentResponse) -> Result<String> {
|
||||||
@@ -2,6 +2,8 @@ use std::str::FromStr;
|
|||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::{client::GeminiClient, error::Result, prelude::TokenProvider};
|
||||||
|
|
||||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
pub enum Role {
|
pub enum Role {
|
||||||
User,
|
User,
|
||||||
@@ -20,7 +22,7 @@ impl ToString for Role {
|
|||||||
impl FromStr for Role {
|
impl FromStr for Role {
|
||||||
type Err = ();
|
type Err = ();
|
||||||
|
|
||||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
|
||||||
match s {
|
match s {
|
||||||
"user" => Ok(Role::User),
|
"user" => Ok(Role::User),
|
||||||
"model" => Ok(Role::Model),
|
"model" => Ok(Role::Model),
|
||||||
@@ -44,17 +46,24 @@ impl Message {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct Conversation {
|
pub struct Dialogue {
|
||||||
pub messages: Vec<Message>,
|
messages: Vec<Message>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Conversation {
|
impl Dialogue {
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
Conversation { messages: vec![] }
|
Dialogue { messages: vec![] }
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn push_message(&mut self, message: Message) {
|
pub async fn do_turn<T: TokenProvider + Clone>(
|
||||||
self.messages.push(message);
|
&mut self,
|
||||||
|
gemini: &GeminiClient<T>,
|
||||||
|
message: &str,
|
||||||
|
) -> Result<Message> {
|
||||||
|
self.messages.push(Message::new(Role::User, message));
|
||||||
|
let response = gemini.prompt_conversation(&self.messages).await?;
|
||||||
|
self.messages.push(response.clone());
|
||||||
|
Ok(response)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1,12 +1,12 @@
|
|||||||
mod conversation;
|
mod client;
|
||||||
|
mod dialogue;
|
||||||
pub mod error;
|
pub mod error;
|
||||||
mod token_provider;
|
mod token_provider;
|
||||||
mod types;
|
mod types;
|
||||||
mod vertex_client;
|
|
||||||
|
|
||||||
pub mod prelude {
|
pub mod prelude {
|
||||||
pub use crate::conversation::*;
|
pub use crate::client::*;
|
||||||
|
pub use crate::dialogue::*;
|
||||||
pub use crate::token_provider::*;
|
pub use crate::token_provider::*;
|
||||||
pub use crate::types::*;
|
pub use crate::types::*;
|
||||||
pub use crate::vertex_client::*;
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user