Add base code and examples for the Vertex API
- Add base code for the Vertex API with types and a client with helper methods for prompting text and conversation. - Add examples for prompting text and conversation.
This commit is contained in:
11
.gitignore
vendored
11
.gitignore
vendored
@@ -1,14 +1,7 @@
|
||||
# Generated by Cargo
|
||||
# will have compiled files and executables
|
||||
debug/
|
||||
target/
|
||||
|
||||
# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries
|
||||
# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
|
||||
Cargo.lock
|
||||
|
||||
# These are backup files generated by rustfmt
|
||||
**/*.rs.bk
|
||||
|
||||
# MSVC Windows builds of rustc generate these, which store debugging information
|
||||
*.pdb
|
||||
/target
|
||||
.cargo/config.toml
|
||||
|
||||
19
Cargo.toml
Normal file
19
Cargo.toml
Normal file
@@ -0,0 +1,19 @@
|
||||
[package]
|
||||
name = "gemini-rs"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
gcp_auth = "0.10.0"
|
||||
reqwest = { version = "0.11.9", features = ["json", "gzip"] }
|
||||
serde = { version = "*", features = ["derive"] }
|
||||
serde_json = { version = "*"}
|
||||
tracing = "0.1.40"
|
||||
|
||||
[dev-dependencies]
|
||||
console = "0.15.8"
|
||||
dialoguer = "0.11.0"
|
||||
indicatif = "0.17.7"
|
||||
tokio = { version = "1.35.1", features = ["full"] }
|
||||
61
examples/conversation.rs
Normal file
61
examples/conversation.rs
Normal file
@@ -0,0 +1,61 @@
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
use console::style;
|
||||
use dialoguer::{theme::ColorfulTheme, Input};
|
||||
use gemini_rs::prelude::*;
|
||||
|
||||
use gcp_auth::AuthenticationManager;
|
||||
use indicatif::{ProgressBar, ProgressStyle};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let authentication_manager = Arc::new(AuthenticationManager::new().await?);
|
||||
let api_endpoint = std::env::var("API_ENDPOINT")?;
|
||||
let project_id = std::env::var("PROJECT_ID")?;
|
||||
let location_id = std::env::var("LOCATION_ID")?;
|
||||
|
||||
let vertex_client = VertexClient::new(
|
||||
authentication_manager,
|
||||
api_endpoint,
|
||||
project_id,
|
||||
location_id,
|
||||
);
|
||||
|
||||
let mut conversation = Conversation::new();
|
||||
loop {
|
||||
let message: String = Input::with_theme(&ColorfulTheme::default())
|
||||
.with_prompt("user")
|
||||
.interact_text()?;
|
||||
|
||||
// Exit the conversation if the user types "exit"
|
||||
if message == "exit" {
|
||||
break;
|
||||
}
|
||||
|
||||
// Push the user's message to the conversation.
|
||||
conversation.push_message(Message::new(Role::User, &message));
|
||||
|
||||
// Show a spinner while the model is thinking.
|
||||
let progress = ProgressBar::new_spinner();
|
||||
progress.enable_steady_tick(Duration::from_millis(120));
|
||||
progress.set_style(ProgressStyle::with_template("{spinner:.green} {msg}")?);
|
||||
progress.set_message("Thinking...");
|
||||
|
||||
// Prompt the model with the conversation so far.
|
||||
let response = vertex_client.prompt_conversation(&conversation).await?;
|
||||
|
||||
// Stop the spinner and clear the terminal.
|
||||
progress.finish_and_clear();
|
||||
|
||||
// Print the model's response.
|
||||
println!(
|
||||
"✨ {} {} {}",
|
||||
style(response.role.to_string()).bold(),
|
||||
style("·").dim(),
|
||||
style(&response.text).cyan()
|
||||
);
|
||||
conversation.push_message(response);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
26
examples/text-from-text.rs
Normal file
26
examples/text-from-text.rs
Normal file
@@ -0,0 +1,26 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use gemini_rs::prelude::*;
|
||||
|
||||
use gcp_auth::AuthenticationManager;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let authentication_manager = Arc::new(AuthenticationManager::new().await?);
|
||||
let api_endpoint = std::env::var("API_ENDPOINT")?;
|
||||
let project_id = std::env::var("PROJECT_ID")?;
|
||||
let location_id = std::env::var("LOCATION_ID")?;
|
||||
|
||||
let vertex_client = VertexClient::new(
|
||||
authentication_manager,
|
||||
api_endpoint,
|
||||
project_id,
|
||||
location_id,
|
||||
);
|
||||
|
||||
let prompt = "What is the airspeed of an unladen swallow?";
|
||||
let result = vertex_client.prompt_text(prompt, None).await?;
|
||||
println!("Response: {}", result);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
60
src/conversation.rs
Normal file
60
src/conversation.rs
Normal file
@@ -0,0 +1,60 @@
|
||||
use std::str::FromStr;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub enum Role {
|
||||
User,
|
||||
Model,
|
||||
}
|
||||
|
||||
impl ToString for Role {
|
||||
fn to_string(&self) -> String {
|
||||
match self {
|
||||
Role::User => "user".to_string(),
|
||||
Role::Model => "model".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for Role {
|
||||
type Err = ();
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s {
|
||||
"user" => Ok(Role::User),
|
||||
"model" => Ok(Role::Model),
|
||||
_ => Err(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Message {
|
||||
pub role: Role,
|
||||
pub text: String,
|
||||
}
|
||||
|
||||
impl Message {
|
||||
pub fn new(role: Role, text: &str) -> Self {
|
||||
Message {
|
||||
role,
|
||||
text: text.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Conversation {
|
||||
pub messages: Vec<Message>,
|
||||
}
|
||||
|
||||
impl Conversation {
|
||||
pub fn new() -> Self {
|
||||
Conversation { messages: vec![] }
|
||||
}
|
||||
|
||||
pub fn push_message(&mut self, message: Message) {
|
||||
self.messages.push(message);
|
||||
}
|
||||
}
|
||||
40
src/error.rs
Normal file
40
src/error.rs
Normal file
@@ -0,0 +1,40 @@
|
||||
use std::fmt::Display;
|
||||
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum Error {
|
||||
Env(std::env::VarError),
|
||||
HttpClient(reqwest::Error),
|
||||
Token(gcp_auth::Error),
|
||||
}
|
||||
|
||||
impl Display for Error {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match &self {
|
||||
Error::Env(e) => write!(f, "Environment variable error: {}", e),
|
||||
Error::HttpClient(e) => write!(f, "HTTP Client error: {}", e),
|
||||
Error::Token(e) => write!(f, "Token error: {}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for Error {}
|
||||
|
||||
impl From<reqwest::Error> for Error {
|
||||
fn from(e: reqwest::Error) -> Self {
|
||||
Error::HttpClient(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<std::env::VarError> for Error {
|
||||
fn from(e: std::env::VarError) -> Self {
|
||||
Error::Env(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<gcp_auth::Error> for Error {
|
||||
fn from(e: gcp_auth::Error) -> Self {
|
||||
Error::Token(e)
|
||||
}
|
||||
}
|
||||
12
src/lib.rs
Normal file
12
src/lib.rs
Normal file
@@ -0,0 +1,12 @@
|
||||
mod conversation;
|
||||
pub mod error;
|
||||
mod token_provider;
|
||||
mod types;
|
||||
mod vertex_client;
|
||||
|
||||
pub mod prelude {
|
||||
pub use crate::conversation::*;
|
||||
pub use crate::token_provider::*;
|
||||
pub use crate::types::*;
|
||||
pub use crate::vertex_client::*;
|
||||
}
|
||||
28
src/token_provider.rs
Normal file
28
src/token_provider.rs
Normal file
@@ -0,0 +1,28 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use gcp_auth::AuthenticationManager;
|
||||
|
||||
use crate::error::Result;
|
||||
|
||||
pub trait TokenProvider {
|
||||
fn get_token(&self, scope: &[&str])
|
||||
-> impl std::future::Future<Output = Result<String>> + Send;
|
||||
}
|
||||
|
||||
impl TokenProvider for Arc<AuthenticationManager> {
|
||||
async fn get_token(&self, scope: &[&str]) -> Result<String> {
|
||||
match AuthenticationManager::get_token(self, scope).await {
|
||||
Ok(token) => Ok(token.as_str().to_string()),
|
||||
Err(e) => Err(e.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TokenProvider for AuthenticationManager {
|
||||
async fn get_token(&self, scope: &[&str]) -> Result<String> {
|
||||
match AuthenticationManager::get_token(self, scope).await {
|
||||
Ok(token) => Ok(token.as_str().to_string()),
|
||||
Err(e) => Err(e.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
130
src/types.rs
Normal file
130
src/types.rs
Normal file
@@ -0,0 +1,130 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct CountTokensRequest {
|
||||
pub contents: Content,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CountTokensResponse {
|
||||
pub total_tokens: i32,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct GenerateContentRequest {
|
||||
pub contents: Vec<Content>,
|
||||
pub generation_config: Option<GenerationConfig>,
|
||||
pub tools: Option<Vec<Tools>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct Tools {
|
||||
pub function_declarations: Option<Vec<FunctionDeclaration>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct Content {
|
||||
pub role: String,
|
||||
pub parts: Vec<Part>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct GenerationConfig {
|
||||
pub max_output_tokens: Option<i32>,
|
||||
pub temperature: Option<f32>,
|
||||
pub top_p: Option<f32>,
|
||||
pub top_k: Option<i32>,
|
||||
pub stop_sequences: Option<Vec<String>>,
|
||||
pub candidate_count: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub enum Part {
|
||||
Text(String),
|
||||
InlineData {
|
||||
mime_type: String,
|
||||
data: String,
|
||||
},
|
||||
FileData {
|
||||
mime_type: String,
|
||||
file_uri: String,
|
||||
},
|
||||
FunctionCall {
|
||||
name: String,
|
||||
args: HashMap<String, String>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct GenerateContentResponse(pub Vec<ResponseStreamChunk>);
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ResponseStreamChunk {
|
||||
pub candidates: Vec<Candidate>,
|
||||
pub usage_metadata: Option<UsageMetadata>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Candidate {
|
||||
pub content: Content,
|
||||
pub citation_metadata: Option<CitationMetadata>,
|
||||
pub safety_ratings: Vec<SafetyRating>,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct SafetyRating {
|
||||
pub category: String,
|
||||
pub probability: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Citation {
|
||||
pub start_index: i32,
|
||||
pub end_index: i32,
|
||||
pub uri: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct CitationMetadata {
|
||||
pub citations: Vec<Citation>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct UsageMetadata {
|
||||
pub candidates_token_count: Option<i32>,
|
||||
pub prompt_token_count: i32,
|
||||
pub total_token_count: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct FunctionDeclaration {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub parameters: FunctionParameters,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct FunctionParameters {
|
||||
pub r#type: String,
|
||||
pub properties: HashMap<String, FunctionParametersProperty>,
|
||||
pub required: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct FunctionParametersProperty {
|
||||
pub r#type: String,
|
||||
pub description: String,
|
||||
}
|
||||
152
src/vertex_client.rs
Normal file
152
src/vertex_client.rs
Normal file
@@ -0,0 +1,152 @@
|
||||
use crate::conversation::{Message, Role};
|
||||
use crate::error::Result;
|
||||
use crate::prelude::{
|
||||
Content, Conversation, GenerateContentRequest, GenerateContentResponse, GenerationConfig,
|
||||
};
|
||||
use crate::{prelude::Part, token_provider::TokenProvider};
|
||||
|
||||
pub static AUTH_SCOPE: &[&str] = &["https://www.googleapis.com/auth/cloud-platform"];
|
||||
|
||||
pub enum Model {
|
||||
GeminiPro,
|
||||
}
|
||||
|
||||
impl ToString for Model {
|
||||
fn to_string(&self) -> String {
|
||||
match self {
|
||||
Model::GeminiPro => "gemini-pro".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct VertexClient<T: TokenProvider + Clone> {
|
||||
token_provider: T,
|
||||
client: reqwest::Client,
|
||||
api_endpoint: String,
|
||||
project_id: String,
|
||||
location_id: String,
|
||||
}
|
||||
|
||||
unsafe impl<T: TokenProvider + Clone> Send for VertexClient<T> {}
|
||||
unsafe impl<T: TokenProvider + Clone> Sync for VertexClient<T> {}
|
||||
|
||||
impl<T: TokenProvider + Clone> VertexClient<T> {
|
||||
pub fn new(
|
||||
token_provider: T,
|
||||
api_endpoint: String,
|
||||
project_id: String,
|
||||
location_id: String,
|
||||
) -> Self {
|
||||
VertexClient {
|
||||
token_provider,
|
||||
client: reqwest::Client::new(),
|
||||
api_endpoint,
|
||||
project_id,
|
||||
location_id,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn stream_generate_content(
|
||||
&self,
|
||||
request: &GenerateContentRequest,
|
||||
model: Model,
|
||||
) -> Result<GenerateContentResponse> {
|
||||
let access_token = self.token_provider.get_token(AUTH_SCOPE).await?;
|
||||
let endpoint_url: String = format!(
|
||||
"https://{}/v1beta1/projects/{}/locations/{}/publishers/google/models/{}:streamGenerateContent", self.api_endpoint, self.project_id, self.location_id, model.to_string(),
|
||||
);
|
||||
let resp = self
|
||||
.client
|
||||
.post(&endpoint_url)
|
||||
.bearer_auth(access_token)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let txt_json = resp.text().await?;
|
||||
tracing::debug!("Vertex API Response: {}", txt_json);
|
||||
Ok(serde_json::from_str(&txt_json).unwrap())
|
||||
}
|
||||
|
||||
/// Prompts a conversation to the model.
|
||||
pub async fn prompt_conversation(&self, conversation: &Conversation) -> Result<Message> {
|
||||
let request = GenerateContentRequest {
|
||||
contents: conversation
|
||||
.messages
|
||||
.iter()
|
||||
.map(|m| Content {
|
||||
role: m.role.to_string(),
|
||||
parts: vec![Part::Text(m.text.clone())],
|
||||
})
|
||||
.collect(),
|
||||
generation_config: None,
|
||||
tools: None,
|
||||
};
|
||||
|
||||
let response = self
|
||||
.stream_generate_content(&request, Model::GeminiPro)
|
||||
.await?;
|
||||
|
||||
let text = response
|
||||
.0
|
||||
.into_iter()
|
||||
.flat_map(|chunk| {
|
||||
chunk.candidates.into_iter().flat_map(|candidate| {
|
||||
candidate
|
||||
.content
|
||||
.parts
|
||||
.into_iter()
|
||||
.map(|part| match part {
|
||||
Part::Text(text) => Some(text),
|
||||
_ => None,
|
||||
})
|
||||
.filter(Option::is_some)
|
||||
.flatten()
|
||||
})
|
||||
})
|
||||
.collect::<String>();
|
||||
Ok(Message::new(Role::Model, &text))
|
||||
}
|
||||
|
||||
/// Sends a text prompt to the Vertex API using the Gemini Pro model and extracts the text
|
||||
/// from the response.
|
||||
pub async fn prompt_text(
|
||||
&self,
|
||||
prompt: &str,
|
||||
generation_config: Option<&GenerationConfig>,
|
||||
) -> Result<String> {
|
||||
let request = GenerateContentRequest {
|
||||
contents: vec![Content {
|
||||
role: "user".to_string(),
|
||||
parts: vec![Part::Text(prompt.to_string())],
|
||||
}],
|
||||
generation_config: generation_config.cloned(),
|
||||
tools: None,
|
||||
};
|
||||
|
||||
let response = self
|
||||
.stream_generate_content(&request, Model::GeminiPro)
|
||||
.await?;
|
||||
|
||||
let text = response
|
||||
.0
|
||||
.into_iter()
|
||||
.flat_map(|chunk| {
|
||||
chunk.candidates.into_iter().flat_map(|candidate| {
|
||||
candidate
|
||||
.content
|
||||
.parts
|
||||
.into_iter()
|
||||
.map(|part| match part {
|
||||
Part::Text(text) => Some(text),
|
||||
_ => None,
|
||||
})
|
||||
.filter(Option::is_some)
|
||||
.flatten()
|
||||
})
|
||||
})
|
||||
.collect::<String>();
|
||||
Ok(text)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user