Add base code and examples for the Vertex API

- Add base code for the Vertex API with types and a client with helper
  methods for prompting text and conversation.
- Add examples for prompting text and conversation.
This commit is contained in:
2024-02-06 20:50:36 +00:00
parent 995d2537b9
commit 67c6ba878b
10 changed files with 530 additions and 9 deletions

11
.gitignore vendored
View File

@@ -1,14 +1,7 @@
# Generated by Cargo
# will have compiled files and executables
debug/ debug/
target/ target/
# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries
# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
Cargo.lock Cargo.lock
# These are backup files generated by rustfmt
**/*.rs.bk **/*.rs.bk
# MSVC Windows builds of rustc generate these, which store debugging information
*.pdb *.pdb
/target
.cargo/config.toml

19
Cargo.toml Normal file
View File

@@ -0,0 +1,19 @@
[package]
name = "gemini-rs"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
gcp_auth = "0.10.0"
reqwest = { version = "0.11.9", features = ["json", "gzip"] }
# Constrained to the 1.x series: a wildcard (`*`) requirement would also accept
# any future, potentially incompatible, major release.
serde = { version = "1", features = ["derive"] }
serde_json = { version = "1" }
tracing = "0.1.40"
[dev-dependencies]
console = "0.15.8"
dialoguer = "0.11.0"
indicatif = "0.17.7"
tokio = { version = "1.35.1", features = ["full"] }

61
examples/conversation.rs Normal file
View File

@@ -0,0 +1,61 @@
use std::{sync::Arc, time::Duration};
use console::style;
use dialoguer::{theme::ColorfulTheme, Input};
use gemini_rs::prelude::*;
use gcp_auth::AuthenticationManager;
use indicatif::{ProgressBar, ProgressStyle};
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Build the authentication manager first, then read the endpoint, project
    // and location from the environment; any failure aborts the example.
    let auth = Arc::new(AuthenticationManager::new().await?);
    let endpoint = std::env::var("API_ENDPOINT")?;
    let project = std::env::var("PROJECT_ID")?;
    let location = std::env::var("LOCATION_ID")?;
    let client = VertexClient::new(auth, endpoint, project, location);

    // The whole history is sent on every turn, so it accumulates both the
    // user's messages and the model's replies.
    let mut history = Conversation::new();
    let theme = ColorfulTheme::default();

    loop {
        let line: String = Input::with_theme(&theme)
            .with_prompt("user")
            .interact_text()?;

        // Typing exactly "exit" ends the session.
        if line == "exit" {
            break;
        }

        history.push_message(Message::new(Role::User, &line));

        // Animated spinner shown while the request is in flight.
        let spinner = ProgressBar::new_spinner();
        spinner.enable_steady_tick(Duration::from_millis(120));
        spinner.set_style(ProgressStyle::with_template("{spinner:.green} {msg}")?);
        spinner.set_message("Thinking...");

        let reply = client.prompt_conversation(&history).await?;

        // Remove the spinner before printing the reply.
        spinner.finish_and_clear();

        let speaker = style(reply.role.to_string()).bold();
        let separator = style("·").dim();
        let body = style(&reply.text).cyan();
        println!("{} {} {}", speaker, separator, body);

        // Keep the model's reply so the next turn has the full context.
        history.push_message(reply);
    }

    Ok(())
}

View File

@@ -0,0 +1,26 @@
use std::sync::Arc;
use gemini_rs::prelude::*;
use gcp_auth::AuthenticationManager;
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Build the authentication manager first, then read the endpoint, project
    // and location from the environment; any failure aborts the example.
    let auth = Arc::new(AuthenticationManager::new().await?);
    let endpoint = std::env::var("API_ENDPOINT")?;
    let project = std::env::var("PROJECT_ID")?;
    let location = std::env::var("LOCATION_ID")?;
    let client = VertexClient::new(auth, endpoint, project, location);

    // One-shot text prompt without a custom generation config.
    let answer = client
        .prompt_text("What is the airspeed of an unladen swallow?", None)
        .await?;
    println!("Response: {}", answer);

    Ok(())
}

60
src/conversation.rs Normal file
View File

@@ -0,0 +1,60 @@
use std::str::FromStr;
use serde::{Deserialize, Serialize};
/// Author of a [`Message`] in a conversation.
///
/// The enum is fieldless, so it is also `Copy` and supports equality and
/// hashing; this lets callers compare roles directly instead of going through
/// their string form.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Role {
    /// A message written by the user.
    User,
    /// A message produced by the model.
    Model,
}
/// Renders a [`Role`] as its lowercase name (`"user"` / `"model"`).
///
/// Implemented through `Display` rather than `ToString` directly: the blanket
/// `ToString` impl keeps `role.to_string()` working while also allowing the
/// role to be used in `format!`/`write!` with `{}`.
impl std::fmt::Display for Role {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Role::User => "user",
            Role::Model => "model",
        })
    }
}
/// Parses the lowercase role names `"user"` and `"model"`.
impl FromStr for Role {
    /// No detail is reported for an unrecognised role name.
    type Err = ();

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Matching is exact and case-sensitive.
        if s == "user" {
            return Ok(Role::User);
        }
        if s == "model" {
            return Ok(Role::Model);
        }
        Err(())
    }
}
/// A single conversation entry: who wrote it and what was said.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Message {
/// Author of the message.
pub role: Role,
/// Plain-text content of the message.
pub text: String,
}
impl Message {
pub fn new(role: Role, text: &str) -> Self {
Message {
role,
text: text.to_string(),
}
}
}
/// An ordered list of [`Message`]s exchanged so far.
///
/// `Default` yields an empty conversation, matching what the argument-less
/// `Conversation::new()` constructor produces.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct Conversation {
    /// Messages in the order they were pushed.
    pub messages: Vec<Message>,
}
impl Conversation {
pub fn new() -> Self {
Conversation { messages: vec![] }
}
pub fn push_message(&mut self, message: Message) {
self.messages.push(message);
}
}

40
src/error.rs Normal file
View File

@@ -0,0 +1,40 @@
use std::fmt::Display;
/// Result alias used throughout the crate, with [`Error`] as the error type.
pub type Result<T> = std::result::Result<T, Error>;
/// Errors produced by this crate; each variant wraps the underlying error.
#[derive(Debug)]
pub enum Error {
/// Reading an environment variable failed.
Env(std::env::VarError),
/// An error reported by the `reqwest` HTTP client.
HttpClient(reqwest::Error),
/// An error reported by `gcp_auth`, e.g. while obtaining a token.
Token(gcp_auth::Error),
}
impl Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match &self {
Error::Env(e) => write!(f, "Environment variable error: {}", e),
Error::HttpClient(e) => write!(f, "HTTP Client error: {}", e),
Error::Token(e) => write!(f, "Token error: {}", e),
}
}
}
impl std::error::Error for Error {}
impl From<reqwest::Error> for Error {
fn from(e: reqwest::Error) -> Self {
Error::HttpClient(e)
}
}
impl From<std::env::VarError> for Error {
fn from(e: std::env::VarError) -> Self {
Error::Env(e)
}
}
impl From<gcp_auth::Error> for Error {
fn from(e: gcp_auth::Error) -> Self {
Error::Token(e)
}
}

12
src/lib.rs Normal file
View File

@@ -0,0 +1,12 @@
mod conversation;
pub mod error;
mod token_provider;
mod types;
mod vertex_client;
/// Convenience re-exports of everything in the private `conversation`,
/// `token_provider`, `types` and `vertex_client` modules, so that
/// `use gemini_rs::prelude::*;` brings the whole public API into scope.
///
/// The `error` module is public on its own and is not re-exported here.
pub mod prelude {
pub use crate::conversation::*;
pub use crate::token_provider::*;
pub use crate::types::*;
pub use crate::vertex_client::*;
}

28
src/token_provider.rs Normal file
View File

@@ -0,0 +1,28 @@
use std::sync::Arc;
use gcp_auth::AuthenticationManager;
use crate::error::Result;
/// Abstraction over anything that can produce an access token for a set of
/// scopes; the client is generic over this trait.
pub trait TokenProvider {
/// Asynchronously obtains a token string valid for `scope`.
///
/// The returned future is required to be `Send`.
fn get_token(&self, scope: &[&str])
-> impl std::future::Future<Output = Result<String>> + Send;
}
/// Token provider backed by a shared (`Arc`) `gcp_auth` authentication manager.
impl TokenProvider for Arc<AuthenticationManager> {
    async fn get_token(&self, scope: &[&str]) -> Result<String> {
        // `?` converts a `gcp_auth::Error` into the crate's error type.
        let token = AuthenticationManager::get_token(self, scope).await?;
        Ok(token.as_str().to_string())
    }
}
/// Token provider backed directly by a `gcp_auth` authentication manager.
impl TokenProvider for AuthenticationManager {
    async fn get_token(&self, scope: &[&str]) -> Result<String> {
        // `?` converts a `gcp_auth::Error` into the crate's error type.
        let token = AuthenticationManager::get_token(self, scope).await?;
        Ok(token.as_str().to_string())
    }
}

130
src/types.rs Normal file
View File

@@ -0,0 +1,130 @@
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
/// Payload wrapping a single [`Content`].
///
/// NOTE(review): by its name this is the request body of a token-count call;
/// no client method in the visible code sends it yet — confirm the endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub struct CountTokensRequest {
/// The content carried by the request.
pub contents: Content,
}
/// Payload holding a total token count; (de)serialized with camelCase field
/// names, i.e. `totalTokens`.
///
/// NOTE(review): presumably the response of a token-count call — confirm.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CountTokensResponse {
/// Total number of tokens.
pub total_tokens: i32,
}
#[derive(Serialize, Deserialize)]
pub struct GenerateContentRequest {
pub contents: Vec<Content>,
pub generation_config: Option<GenerationConfig>,
pub tools: Option<Vec<Tools>>,
}
#[derive(Serialize, Deserialize)]
pub struct Tools {
pub function_declarations: Option<Vec<FunctionDeclaration>>,
}
/// One message of a request or response: an author role plus its parts.
#[derive(Debug, Serialize, Deserialize)]
pub struct Content {
/// Author of the content as a string; the client fills this with
/// `Role::to_string()` or the literal `"user"`.
pub role: String,
/// The pieces that make up the content.
pub parts: Vec<Part>,
}
/// Optional generation settings, (de)serialized with camelCase field names.
///
/// Every field is optional and `Default` leaves them all `None`. No
/// skip attribute is present, so `None` fields serialize as JSON `null`.
///
/// NOTE(review): the per-field descriptions below are inferred from the field
/// names — confirm against the API reference.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct GenerationConfig {
/// Presumably the upper bound on generated tokens.
pub max_output_tokens: Option<i32>,
/// Presumably the sampling temperature.
pub temperature: Option<f32>,
/// Presumably the nucleus-sampling (top-p) threshold.
pub top_p: Option<f32>,
/// Presumably the top-k sampling cutoff.
pub top_k: Option<i32>,
/// Presumably sequences that stop generation when produced.
pub stop_sequences: Option<Vec<String>>,
/// Presumably the number of candidates to generate.
pub candidate_count: Option<u32>,
}
/// One piece of a [`Content`]: text, inline data, a file reference or a
/// function call.
///
/// Externally tagged with camelCase variant names (`text`, `inlineData`,
/// `fileData`, `functionCall`). The enum-level `rename_all` only renames the
/// variants, so the struct variants carry their own `rename_all` to get
/// camelCase field names (`mimeType`, `fileUri`) as well; without it a
/// camelCase payload fails to deserialize with a missing-field error. The
/// previous snake_case field names are still accepted on input.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum Part {
    /// Plain text.
    Text(String),
    /// Data embedded directly in the message.
    #[serde(rename_all = "camelCase")]
    InlineData {
        /// MIME type of `data`.
        #[serde(alias = "mime_type")]
        mime_type: String,
        /// The payload as a string.
        /// NOTE(review): presumably base64-encoded — confirm.
        data: String,
    },
    /// Data referenced by URI.
    #[serde(rename_all = "camelCase")]
    FileData {
        /// MIME type of the referenced file.
        #[serde(alias = "mime_type")]
        mime_type: String,
        /// Location of the file.
        #[serde(alias = "file_uri")]
        file_uri: String,
    },
    /// A function call with its name and arguments.
    FunctionCall {
        /// Name of the function.
        name: String,
        /// Arguments keyed by parameter name.
        /// NOTE(review): only string values deserialize here; non-string
        /// argument values would be rejected — confirm this is intended.
        args: HashMap<String, String>,
    },
}
/// Full response of a `streamGenerateContent` call: the list of chunks.
///
/// As a newtype over `Vec`, it (de)serializes as a plain JSON array.
#[derive(Debug, Serialize, Deserialize)]
pub struct GenerateContentResponse(pub Vec<ResponseStreamChunk>);
/// One element of a [`GenerateContentResponse`], with camelCase field names.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ResponseStreamChunk {
/// Candidates carried by this chunk; required when deserializing.
pub candidates: Vec<Candidate>,
/// Token usage information, when present.
pub usage_metadata: Option<UsageMetadata>,
}
/// A single candidate inside a response chunk, with camelCase field names.
///
/// NOTE(review): `content` and `safety_ratings` are required here, so a chunk
/// whose candidate omits either one fails to deserialize — confirm the API
/// always sends both (e.g. for blocked or final chunks).
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Candidate {
/// The content of this candidate.
pub content: Content,
/// Citation information, when present.
pub citation_metadata: Option<CitationMetadata>,
/// Safety ratings attached to this candidate.
pub safety_ratings: Vec<SafetyRating>,
/// Finish reason as a raw string, when present.
pub finish_reason: Option<String>,
}
/// A safety rating: a category and a probability, both kept as raw strings.
#[derive(Debug, Serialize, Deserialize)]
pub struct SafetyRating {
/// Rated category, as sent by the API.
pub category: String,
/// Probability label, as sent by the API.
pub probability: String,
}
/// A single citation, with camelCase field names.
///
/// NOTE(review): all three fields are required; a citation that omits one
/// (for example an index left out of the JSON) fails to deserialize —
/// confirm against real responses.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Citation {
/// Start of the cited span.
/// NOTE(review): unit (bytes vs characters) is not visible here.
pub start_index: i32,
/// End of the cited span.
pub end_index: i32,
/// URI of the cited source.
pub uri: String,
}
/// Wrapper around the list of [`Citation`]s attached to a candidate.
#[derive(Debug, Serialize, Deserialize)]
pub struct CitationMetadata {
/// The citations; required when deserializing.
pub citations: Vec<Citation>,
}
/// Token counts attached to a response chunk, with camelCase field names.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct UsageMetadata {
/// Token count of the candidates, when present.
pub candidates_token_count: Option<i32>,
/// Token count of the prompt; required when deserializing.
pub prompt_token_count: i32,
/// Total token count; required when deserializing.
pub total_token_count: i32,
}
/// Declaration of a function exposed to the model through [`Tools`].
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionDeclaration {
/// Name of the function.
pub name: String,
/// Free-text description of the function.
pub description: String,
/// Description of the function's parameters.
pub parameters: FunctionParameters,
}
/// Parameter description of a [`FunctionDeclaration`].
///
/// NOTE(review): the shape (`type` / `properties` / `required`) looks like a
/// JSON-Schema-style object description — confirm against the API reference.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionParameters {
/// Type name; serialized under the key `type` (raw identifier in Rust).
pub r#type: String,
/// Per-parameter descriptions keyed by parameter name.
pub properties: HashMap<String, FunctionParametersProperty>,
/// Names of the parameters that must be supplied.
pub required: Vec<String>,
}
/// Description of a single parameter inside [`FunctionParameters`].
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionParametersProperty {
/// Type name; serialized under the key `type` (raw identifier in Rust).
pub r#type: String,
/// Free-text description of the parameter.
pub description: String,
}

152
src/vertex_client.rs Normal file
View File

@@ -0,0 +1,152 @@
use crate::conversation::{Message, Role};
use crate::error::Result;
use crate::prelude::{
Content, Conversation, GenerateContentRequest, GenerateContentResponse, GenerationConfig,
};
use crate::{prelude::Part, token_provider::TokenProvider};
/// OAuth scope list passed to the token provider for every API request.
pub static AUTH_SCOPE: &[&str] = &["https://www.googleapis.com/auth/cloud-platform"];
/// Models that can be targeted by the client.
///
/// Fieldless, so it derives the usual value-type traits; in particular
/// `Debug` lets it appear in logs and in `assert_eq!` output.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Model {
    /// The `gemini-pro` model.
    GeminiPro,
}
/// Renders a [`Model`] as the identifier used in the endpoint URL.
///
/// Implemented through `Display` rather than `ToString` directly: the blanket
/// `ToString` impl keeps `model.to_string()` working while also allowing the
/// model to be used in `format!`/`write!` with `{}`.
impl std::fmt::Display for Model {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Model::GeminiPro => "gemini-pro",
        })
    }
}
/// HTTP client for the Vertex AI generative endpoints, generic over the
/// source of access tokens.
#[derive(Clone)]
pub struct VertexClient<T: TokenProvider + Clone> {
/// Supplies the bearer token attached to each request.
token_provider: T,
/// Underlying `reqwest` client used to send requests.
client: reqwest::Client,
/// Host name placed after `https://` in the request URL.
api_endpoint: String,
/// Project id inserted into the request path.
project_id: String,
/// Location id inserted into the request path.
location_id: String,
}
// `Send` and `Sync` are intentionally NOT implemented by hand for
// `VertexClient<T>`. They are auto traits, so the compiler already provides
// them whenever every field is `Send`/`Sync`. An unconditional `unsafe impl`
// would additionally claim thread-safety for a token provider `T` that is not
// thread-safe, which is unsound.
impl<T: TokenProvider + Clone> VertexClient<T> {
pub fn new(
token_provider: T,
api_endpoint: String,
project_id: String,
location_id: String,
) -> Self {
VertexClient {
token_provider,
client: reqwest::Client::new(),
api_endpoint,
project_id,
location_id,
}
}
pub async fn stream_generate_content(
&self,
request: &GenerateContentRequest,
model: Model,
) -> Result<GenerateContentResponse> {
let access_token = self.token_provider.get_token(AUTH_SCOPE).await?;
let endpoint_url: String = format!(
"https://{}/v1beta1/projects/{}/locations/{}/publishers/google/models/{}:streamGenerateContent", self.api_endpoint, self.project_id, self.location_id, model.to_string(),
);
let resp = self
.client
.post(&endpoint_url)
.bearer_auth(access_token)
.json(&request)
.send()
.await?;
let txt_json = resp.text().await?;
tracing::debug!("Vertex API Response: {}", txt_json);
Ok(serde_json::from_str(&txt_json).unwrap())
}
/// Sends every message of `conversation` to the Gemini Pro model and returns
/// the reply as a [`Message`] with the model role.
///
/// The reply text is the concatenation, in order, of all text parts found in
/// all candidates of all response chunks; non-text parts are ignored.
pub async fn prompt_conversation(&self, conversation: &Conversation) -> Result<Message> {
    // Translate each conversation message into a single-part text content.
    let mut contents = Vec::with_capacity(conversation.messages.len());
    for message in &conversation.messages {
        contents.push(Content {
            role: message.role.to_string(),
            parts: vec![Part::Text(message.text.clone())],
        });
    }
    let request = GenerateContentRequest {
        contents,
        generation_config: None,
        tools: None,
    };

    let response = self
        .stream_generate_content(&request, Model::GeminiPro)
        .await?;

    // Stitch the text fragments back together.
    let mut reply = String::new();
    for chunk in response.0 {
        for candidate in chunk.candidates {
            for part in candidate.content.parts {
                if let Part::Text(fragment) = part {
                    reply.push_str(&fragment);
                }
            }
        }
    }

    Ok(Message::new(Role::Model, &reply))
}
/// Sends `prompt` as a single user message to the Gemini Pro model, with the
/// optional `generation_config`, and returns the text of the response.
///
/// The returned string is the concatenation, in order, of all text parts
/// found in all candidates of all response chunks; non-text parts are
/// ignored.
pub async fn prompt_text(
    &self,
    prompt: &str,
    generation_config: Option<&GenerationConfig>,
) -> Result<String> {
    let user_content = Content {
        role: "user".to_string(),
        parts: vec![Part::Text(prompt.to_string())],
    };
    let request = GenerateContentRequest {
        contents: vec![user_content],
        generation_config: generation_config.cloned(),
        tools: None,
    };

    let response = self
        .stream_generate_content(&request, Model::GeminiPro)
        .await?;

    // Flatten chunks -> candidates -> parts and keep only the text parts.
    let text: String = response
        .0
        .into_iter()
        .flat_map(|chunk| chunk.candidates)
        .flat_map(|candidate| candidate.content.parts)
        .filter_map(|part| match part {
            Part::Text(fragment) => Some(fragment),
            _ => None,
        })
        .collect();

    Ok(text)
}
}