From 67c6ba878b73dd7841883e189559cb077c829239 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Cipriani=20Bandarra?= Date: Tue, 6 Feb 2024 20:50:36 +0000 Subject: [PATCH] Add base code and examples for the Vertex API - Add base code for the Vertex API with types and a client with helper methods for prompting text and conversation. - Add examples for prompting text and conversation. --- .gitignore | 11 +-- Cargo.toml | 19 +++++ examples/conversation.rs | 61 +++++++++++++++ examples/text-from-text.rs | 26 +++++++ src/conversation.rs | 60 +++++++++++++++ src/error.rs | 40 ++++++++++ src/lib.rs | 12 +++ src/token_provider.rs | 28 +++++++ src/types.rs | 130 +++++++++++++++++++++++++++++++ src/vertex_client.rs | 152 +++++++++++++++++++++++++++++++++++++ 10 files changed, 530 insertions(+), 9 deletions(-) create mode 100644 Cargo.toml create mode 100644 examples/conversation.rs create mode 100644 examples/text-from-text.rs create mode 100644 src/conversation.rs create mode 100644 src/error.rs create mode 100644 src/lib.rs create mode 100644 src/token_provider.rs create mode 100644 src/types.rs create mode 100644 src/vertex_client.rs diff --git a/.gitignore b/.gitignore index 6985cf1..f92d917 100644 --- a/.gitignore +++ b/.gitignore @@ -1,14 +1,7 @@ -# Generated by Cargo -# will have compiled files and executables debug/ target/ - -# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries -# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html Cargo.lock - -# These are backup files generated by rustfmt **/*.rs.bk - -# MSVC Windows builds of rustc generate these, which store debugging information *.pdb +/target +.cargo/config.toml diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..2e09bbe --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "gemini-rs" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +gcp_auth = "0.10.0" +reqwest = { version = "0.11.9", features = ["json", "gzip"] } +serde = { version = "*", features = ["derive"] } +serde_json = { version = "*"} +tracing = "0.1.40" + +[dev-dependencies] +console = "0.15.8" +dialoguer = "0.11.0" +indicatif = "0.17.7" +tokio = { version = "1.35.1", features = ["full"] } diff --git a/examples/conversation.rs b/examples/conversation.rs new file mode 100644 index 0000000..fa45e0f --- /dev/null +++ b/examples/conversation.rs @@ -0,0 +1,61 @@ +use std::{sync::Arc, time::Duration}; + +use console::style; +use dialoguer::{theme::ColorfulTheme, Input}; +use gemini_rs::prelude::*; + +use gcp_auth::AuthenticationManager; +use indicatif::{ProgressBar, ProgressStyle}; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let authentication_manager = Arc::new(AuthenticationManager::new().await?); + let api_endpoint = std::env::var("API_ENDPOINT")?; + let project_id = std::env::var("PROJECT_ID")?; + let location_id = std::env::var("LOCATION_ID")?; + + let vertex_client = VertexClient::new( + authentication_manager, + api_endpoint, + project_id, + location_id, + ); + + let mut conversation = Conversation::new(); + loop { + let message: String = Input::with_theme(&ColorfulTheme::default()) + .with_prompt("user") + .interact_text()?; + + // Exit the conversation if the user types "exit" + if message == "exit" { + break; + } + + // Push the user's message to the conversation. + conversation.push_message(Message::new(Role::User, &message)); + + // Show a spinner while the model is thinking. + let progress = ProgressBar::new_spinner(); + progress.enable_steady_tick(Duration::from_millis(120)); + progress.set_style(ProgressStyle::with_template("{spinner:.green} {msg}")?); + progress.set_message("Thinking..."); + + // Prompt the model with the conversation so far. + let response = vertex_client.prompt_conversation(&conversation).await?; + + // Stop the spinner and clear the terminal. + progress.finish_and_clear(); + + // Print the model's response. + println!( + "✨ {} {} {}", + style(response.role.to_string()).bold(), + style("·").dim(), + style(&response.text).cyan() + ); + conversation.push_message(response); + } + + Ok(()) +} diff --git a/examples/text-from-text.rs b/examples/text-from-text.rs new file mode 100644 index 0000000..b944482 --- /dev/null +++ b/examples/text-from-text.rs @@ -0,0 +1,26 @@ +use std::sync::Arc; + +use gemini_rs::prelude::*; + +use gcp_auth::AuthenticationManager; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let authentication_manager = Arc::new(AuthenticationManager::new().await?); + let api_endpoint = std::env::var("API_ENDPOINT")?; + let project_id = std::env::var("PROJECT_ID")?; + let location_id = std::env::var("LOCATION_ID")?; + + let vertex_client = VertexClient::new( + authentication_manager, + api_endpoint, + project_id, + location_id, + ); + + let prompt = "What is the airspeed of an unladen swallow?"; + let result = vertex_client.prompt_text(prompt, None).await?; + println!("Response: {}", result); + + Ok(()) +} diff --git a/src/conversation.rs b/src/conversation.rs new file mode 100644 index 0000000..0a8c6c2 --- /dev/null +++ b/src/conversation.rs @@ -0,0 +1,60 @@ +use std::str::FromStr; + +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub enum Role { + User, + Model, +} + +impl ToString for Role { + fn to_string(&self) -> String { + match self { + Role::User => "user".to_string(), + Role::Model => "model".to_string(), + } + } +} + +impl FromStr for Role { + type Err = (); + + fn from_str(s: &str) -> Result { + match s { + "user" => Ok(Role::User), + "model" => Ok(Role::Model), + _ => Err(()), + } + } +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Message { + pub role: Role, + pub text: String, +} + +impl Message { + pub fn new(role: Role, text: &str) -> Self { + Message { + role, + text: text.to_string(), + } + } +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Conversation { + pub messages: Vec, +} + +impl Conversation { + pub fn new() -> Self { + Conversation { messages: vec![] } + } + + pub fn push_message(&mut self, message: Message) { + self.messages.push(message); + } +} diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..ec30308 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,40 @@ +use std::fmt::Display; + +pub type Result = std::result::Result; + +#[derive(Debug)] +pub enum Error { + Env(std::env::VarError), + HttpClient(reqwest::Error), + Token(gcp_auth::Error), +} + +impl Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match &self { + Error::Env(e) => write!(f, "Environment variable error: {}", e), + Error::HttpClient(e) => write!(f, "HTTP Client error: {}", e), + Error::Token(e) => write!(f, "Token error: {}", e), + } + } +} + +impl std::error::Error for Error {} + +impl From for Error { + fn from(e: reqwest::Error) -> Self { + Error::HttpClient(e) + } +} + +impl From for Error { + fn from(e: std::env::VarError) -> Self { + Error::Env(e) + } +} + +impl From for Error { + fn from(e: gcp_auth::Error) -> Self { + Error::Token(e) + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..d95dbe2 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,12 @@ +mod conversation; +pub mod error; +mod token_provider; +mod types; +mod vertex_client; + +pub mod prelude { + pub use crate::conversation::*; + pub use crate::token_provider::*; + pub use crate::types::*; + pub use crate::vertex_client::*; +} diff --git a/src/token_provider.rs b/src/token_provider.rs new file mode 100644 index 0000000..d506ed3 --- /dev/null +++ b/src/token_provider.rs @@ -0,0 +1,28 @@ +use std::sync::Arc; + +use gcp_auth::AuthenticationManager; + +use crate::error::Result; + +pub trait TokenProvider { + fn get_token(&self, scope: &[&str]) + -> impl std::future::Future> + Send; +} + +impl TokenProvider for Arc { + async fn get_token(&self, scope: &[&str]) -> Result { + match AuthenticationManager::get_token(self, scope).await { + Ok(token) => Ok(token.as_str().to_string()), + Err(e) => Err(e.into()), + } + } +} + +impl TokenProvider for AuthenticationManager { + async fn get_token(&self, scope: &[&str]) -> Result { + match AuthenticationManager::get_token(self, scope).await { + Ok(token) => Ok(token.as_str().to_string()), + Err(e) => Err(e.into()), + } + } +} diff --git a/src/types.rs b/src/types.rs new file mode 100644 index 0000000..0550cfa --- /dev/null +++ b/src/types.rs @@ -0,0 +1,130 @@ +use std::collections::HashMap; + +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize)] +pub struct CountTokensRequest { + pub contents: Content, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CountTokensResponse { + pub total_tokens: i32, +} + +#[derive(Serialize, Deserialize)] +pub struct GenerateContentRequest { + pub contents: Vec, + pub generation_config: Option, + pub tools: Option>, +} + +#[derive(Serialize, Deserialize)] +pub struct Tools { + pub function_declarations: Option>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Content { + pub role: String, + pub parts: Vec, +} + +#[derive(Clone, Debug, Serialize, Deserialize, Default)] +#[serde(rename_all = "camelCase")] +pub struct GenerationConfig { + pub max_output_tokens: Option, + pub temperature: Option, + pub top_p: Option, + pub top_k: Option, + pub stop_sequences: Option>, + pub candidate_count: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub enum Part { + Text(String), + InlineData { + mime_type: String, + data: String, + }, + FileData { + mime_type: String, + file_uri: String, + }, + FunctionCall { + name: String, + args: HashMap, + }, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct GenerateContentResponse(pub Vec); + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ResponseStreamChunk { + pub candidates: Vec, + pub usage_metadata: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Candidate { + pub content: Content, + pub citation_metadata: Option, + pub safety_ratings: Vec, + pub finish_reason: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct SafetyRating { + pub category: String, + pub probability: String, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Citation { + pub start_index: i32, + pub end_index: i32, + pub uri: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct CitationMetadata { + pub citations: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct UsageMetadata { + pub candidates_token_count: Option, + pub prompt_token_count: i32, + pub total_token_count: i32, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct FunctionDeclaration { + pub name: String, + pub description: String, + pub parameters: FunctionParameters, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct FunctionParameters { + pub r#type: String, + pub properties: HashMap, + pub required: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct FunctionParametersProperty { + pub r#type: String, + pub description: String, +} diff --git a/src/vertex_client.rs b/src/vertex_client.rs new file mode 100644 index 0000000..b01cda4 --- /dev/null +++ b/src/vertex_client.rs @@ -0,0 +1,152 @@ +use crate::conversation::{Message, Role}; +use crate::error::Result; +use crate::prelude::{ + Content, Conversation, GenerateContentRequest, GenerateContentResponse, GenerationConfig, +}; +use crate::{prelude::Part, token_provider::TokenProvider}; + +pub static AUTH_SCOPE: &[&str] = &["https://www.googleapis.com/auth/cloud-platform"]; + +pub enum Model { + GeminiPro, +} + +impl ToString for Model { + fn to_string(&self) -> String { + match self { + Model::GeminiPro => "gemini-pro".to_string(), + } + } +} + +#[derive(Clone)] +pub struct VertexClient { + token_provider: T, + client: reqwest::Client, + api_endpoint: String, + project_id: String, + location_id: String, +} + +unsafe impl Send for VertexClient {} +unsafe impl Sync for VertexClient {} + +impl VertexClient { + pub fn new( + token_provider: T, + api_endpoint: String, + project_id: String, + location_id: String, + ) -> Self { + VertexClient { + token_provider, + client: reqwest::Client::new(), + api_endpoint, + project_id, + location_id, + } + } + + pub async fn stream_generate_content( + &self, + request: &GenerateContentRequest, + model: Model, + ) -> Result { + let access_token = self.token_provider.get_token(AUTH_SCOPE).await?; + let endpoint_url: String = format!( + "https://{}/v1beta1/projects/{}/locations/{}/publishers/google/models/{}:streamGenerateContent", self.api_endpoint, self.project_id, self.location_id, model.to_string(), + ); + let resp = self + .client + .post(&endpoint_url) + .bearer_auth(access_token) + .json(&request) + .send() + .await?; + + let txt_json = resp.text().await?; + tracing::debug!("Vertex API Response: {}", txt_json); + Ok(serde_json::from_str(&txt_json).unwrap()) + } + + /// Prompts a conversation to the model. + pub async fn prompt_conversation(&self, conversation: &Conversation) -> Result { + let request = GenerateContentRequest { + contents: conversation + .messages + .iter() + .map(|m| Content { + role: m.role.to_string(), + parts: vec![Part::Text(m.text.clone())], + }) + .collect(), + generation_config: None, + tools: None, + }; + + let response = self + .stream_generate_content(&request, Model::GeminiPro) + .await?; + + let text = response + .0 + .into_iter() + .flat_map(|chunk| { + chunk.candidates.into_iter().flat_map(|candidate| { + candidate + .content + .parts + .into_iter() + .map(|part| match part { + Part::Text(text) => Some(text), + _ => None, + }) + .filter(Option::is_some) + .flatten() + }) + }) + .collect::(); + Ok(Message::new(Role::Model, &text)) + } + + /// Sends a text prompt to the Vertex API using the Gemini Pro model and extracts the text + /// from the response. + pub async fn prompt_text( + &self, + prompt: &str, + generation_config: Option<&GenerationConfig>, + ) -> Result { + let request = GenerateContentRequest { + contents: vec![Content { + role: "user".to_string(), + parts: vec![Part::Text(prompt.to_string())], + }], + generation_config: generation_config.cloned(), + tools: None, + }; + + let response = self + .stream_generate_content(&request, Model::GeminiPro) + .await?; + + let text = response + .0 + .into_iter() + .flat_map(|chunk| { + chunk.candidates.into_iter().flat_map(|candidate| { + candidate + .content + .parts + .into_iter() + .map(|part| match part { + Part::Text(text) => Some(text), + _ => None, + }) + .filter(Option::is_some) + .flatten() + }) + }) + .collect::(); + Ok(text) + } +}