Add base code and examples for the Vertex API
- Add base code for the Vertex API with types and a client with helper methods for prompting text and conversation. - Add examples for prompting text and conversation.
This commit is contained in:
60
src/conversation.rs
Normal file
60
src/conversation.rs
Normal file
@@ -0,0 +1,60 @@
|
||||
use std::str::FromStr;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub enum Role {
|
||||
User,
|
||||
Model,
|
||||
}
|
||||
|
||||
impl ToString for Role {
|
||||
fn to_string(&self) -> String {
|
||||
match self {
|
||||
Role::User => "user".to_string(),
|
||||
Role::Model => "model".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for Role {
|
||||
type Err = ();
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s {
|
||||
"user" => Ok(Role::User),
|
||||
"model" => Ok(Role::Model),
|
||||
_ => Err(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Message {
|
||||
pub role: Role,
|
||||
pub text: String,
|
||||
}
|
||||
|
||||
impl Message {
|
||||
pub fn new(role: Role, text: &str) -> Self {
|
||||
Message {
|
||||
role,
|
||||
text: text.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Conversation {
|
||||
pub messages: Vec<Message>,
|
||||
}
|
||||
|
||||
impl Conversation {
|
||||
pub fn new() -> Self {
|
||||
Conversation { messages: vec![] }
|
||||
}
|
||||
|
||||
pub fn push_message(&mut self, message: Message) {
|
||||
self.messages.push(message);
|
||||
}
|
||||
}
|
||||
40
src/error.rs
Normal file
40
src/error.rs
Normal file
@@ -0,0 +1,40 @@
|
||||
use std::fmt::Display;
|
||||
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum Error {
|
||||
Env(std::env::VarError),
|
||||
HttpClient(reqwest::Error),
|
||||
Token(gcp_auth::Error),
|
||||
}
|
||||
|
||||
impl Display for Error {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match &self {
|
||||
Error::Env(e) => write!(f, "Environment variable error: {}", e),
|
||||
Error::HttpClient(e) => write!(f, "HTTP Client error: {}", e),
|
||||
Error::Token(e) => write!(f, "Token error: {}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for Error {}
|
||||
|
||||
impl From<reqwest::Error> for Error {
|
||||
fn from(e: reqwest::Error) -> Self {
|
||||
Error::HttpClient(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<std::env::VarError> for Error {
|
||||
fn from(e: std::env::VarError) -> Self {
|
||||
Error::Env(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<gcp_auth::Error> for Error {
|
||||
fn from(e: gcp_auth::Error) -> Self {
|
||||
Error::Token(e)
|
||||
}
|
||||
}
|
||||
12
src/lib.rs
Normal file
12
src/lib.rs
Normal file
@@ -0,0 +1,12 @@
|
||||
mod conversation;
|
||||
pub mod error;
|
||||
mod token_provider;
|
||||
mod types;
|
||||
mod vertex_client;
|
||||
|
||||
pub mod prelude {
|
||||
pub use crate::conversation::*;
|
||||
pub use crate::token_provider::*;
|
||||
pub use crate::types::*;
|
||||
pub use crate::vertex_client::*;
|
||||
}
|
||||
28
src/token_provider.rs
Normal file
28
src/token_provider.rs
Normal file
@@ -0,0 +1,28 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use gcp_auth::AuthenticationManager;
|
||||
|
||||
use crate::error::Result;
|
||||
|
||||
pub trait TokenProvider {
|
||||
fn get_token(&self, scope: &[&str])
|
||||
-> impl std::future::Future<Output = Result<String>> + Send;
|
||||
}
|
||||
|
||||
impl TokenProvider for Arc<AuthenticationManager> {
|
||||
async fn get_token(&self, scope: &[&str]) -> Result<String> {
|
||||
match AuthenticationManager::get_token(self, scope).await {
|
||||
Ok(token) => Ok(token.as_str().to_string()),
|
||||
Err(e) => Err(e.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TokenProvider for AuthenticationManager {
|
||||
async fn get_token(&self, scope: &[&str]) -> Result<String> {
|
||||
match AuthenticationManager::get_token(self, scope).await {
|
||||
Ok(token) => Ok(token.as_str().to_string()),
|
||||
Err(e) => Err(e.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
130
src/types.rs
Normal file
130
src/types.rs
Normal file
@@ -0,0 +1,130 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct CountTokensRequest {
|
||||
pub contents: Content,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CountTokensResponse {
|
||||
pub total_tokens: i32,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct GenerateContentRequest {
|
||||
pub contents: Vec<Content>,
|
||||
pub generation_config: Option<GenerationConfig>,
|
||||
pub tools: Option<Vec<Tools>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct Tools {
|
||||
pub function_declarations: Option<Vec<FunctionDeclaration>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct Content {
|
||||
pub role: String,
|
||||
pub parts: Vec<Part>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct GenerationConfig {
|
||||
pub max_output_tokens: Option<i32>,
|
||||
pub temperature: Option<f32>,
|
||||
pub top_p: Option<f32>,
|
||||
pub top_k: Option<i32>,
|
||||
pub stop_sequences: Option<Vec<String>>,
|
||||
pub candidate_count: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub enum Part {
|
||||
Text(String),
|
||||
InlineData {
|
||||
mime_type: String,
|
||||
data: String,
|
||||
},
|
||||
FileData {
|
||||
mime_type: String,
|
||||
file_uri: String,
|
||||
},
|
||||
FunctionCall {
|
||||
name: String,
|
||||
args: HashMap<String, String>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct GenerateContentResponse(pub Vec<ResponseStreamChunk>);
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ResponseStreamChunk {
|
||||
pub candidates: Vec<Candidate>,
|
||||
pub usage_metadata: Option<UsageMetadata>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Candidate {
|
||||
pub content: Content,
|
||||
pub citation_metadata: Option<CitationMetadata>,
|
||||
pub safety_ratings: Vec<SafetyRating>,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct SafetyRating {
|
||||
pub category: String,
|
||||
pub probability: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Citation {
|
||||
pub start_index: i32,
|
||||
pub end_index: i32,
|
||||
pub uri: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct CitationMetadata {
|
||||
pub citations: Vec<Citation>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct UsageMetadata {
|
||||
pub candidates_token_count: Option<i32>,
|
||||
pub prompt_token_count: i32,
|
||||
pub total_token_count: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct FunctionDeclaration {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub parameters: FunctionParameters,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct FunctionParameters {
|
||||
pub r#type: String,
|
||||
pub properties: HashMap<String, FunctionParametersProperty>,
|
||||
pub required: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct FunctionParametersProperty {
|
||||
pub r#type: String,
|
||||
pub description: String,
|
||||
}
|
||||
152
src/vertex_client.rs
Normal file
152
src/vertex_client.rs
Normal file
@@ -0,0 +1,152 @@
|
||||
use crate::conversation::{Message, Role};
|
||||
use crate::error::Result;
|
||||
use crate::prelude::{
|
||||
Content, Conversation, GenerateContentRequest, GenerateContentResponse, GenerationConfig,
|
||||
};
|
||||
use crate::{prelude::Part, token_provider::TokenProvider};
|
||||
|
||||
pub static AUTH_SCOPE: &[&str] = &["https://www.googleapis.com/auth/cloud-platform"];
|
||||
|
||||
pub enum Model {
|
||||
GeminiPro,
|
||||
}
|
||||
|
||||
impl ToString for Model {
|
||||
fn to_string(&self) -> String {
|
||||
match self {
|
||||
Model::GeminiPro => "gemini-pro".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct VertexClient<T: TokenProvider + Clone> {
|
||||
token_provider: T,
|
||||
client: reqwest::Client,
|
||||
api_endpoint: String,
|
||||
project_id: String,
|
||||
location_id: String,
|
||||
}
|
||||
|
||||
unsafe impl<T: TokenProvider + Clone> Send for VertexClient<T> {}
|
||||
unsafe impl<T: TokenProvider + Clone> Sync for VertexClient<T> {}
|
||||
|
||||
impl<T: TokenProvider + Clone> VertexClient<T> {
|
||||
pub fn new(
|
||||
token_provider: T,
|
||||
api_endpoint: String,
|
||||
project_id: String,
|
||||
location_id: String,
|
||||
) -> Self {
|
||||
VertexClient {
|
||||
token_provider,
|
||||
client: reqwest::Client::new(),
|
||||
api_endpoint,
|
||||
project_id,
|
||||
location_id,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn stream_generate_content(
|
||||
&self,
|
||||
request: &GenerateContentRequest,
|
||||
model: Model,
|
||||
) -> Result<GenerateContentResponse> {
|
||||
let access_token = self.token_provider.get_token(AUTH_SCOPE).await?;
|
||||
let endpoint_url: String = format!(
|
||||
"https://{}/v1beta1/projects/{}/locations/{}/publishers/google/models/{}:streamGenerateContent", self.api_endpoint, self.project_id, self.location_id, model.to_string(),
|
||||
);
|
||||
let resp = self
|
||||
.client
|
||||
.post(&endpoint_url)
|
||||
.bearer_auth(access_token)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let txt_json = resp.text().await?;
|
||||
tracing::debug!("Vertex API Response: {}", txt_json);
|
||||
Ok(serde_json::from_str(&txt_json).unwrap())
|
||||
}
|
||||
|
||||
/// Prompts a conversation to the model.
|
||||
pub async fn prompt_conversation(&self, conversation: &Conversation) -> Result<Message> {
|
||||
let request = GenerateContentRequest {
|
||||
contents: conversation
|
||||
.messages
|
||||
.iter()
|
||||
.map(|m| Content {
|
||||
role: m.role.to_string(),
|
||||
parts: vec![Part::Text(m.text.clone())],
|
||||
})
|
||||
.collect(),
|
||||
generation_config: None,
|
||||
tools: None,
|
||||
};
|
||||
|
||||
let response = self
|
||||
.stream_generate_content(&request, Model::GeminiPro)
|
||||
.await?;
|
||||
|
||||
let text = response
|
||||
.0
|
||||
.into_iter()
|
||||
.flat_map(|chunk| {
|
||||
chunk.candidates.into_iter().flat_map(|candidate| {
|
||||
candidate
|
||||
.content
|
||||
.parts
|
||||
.into_iter()
|
||||
.map(|part| match part {
|
||||
Part::Text(text) => Some(text),
|
||||
_ => None,
|
||||
})
|
||||
.filter(Option::is_some)
|
||||
.flatten()
|
||||
})
|
||||
})
|
||||
.collect::<String>();
|
||||
Ok(Message::new(Role::Model, &text))
|
||||
}
|
||||
|
||||
/// Sends a text prompt to the Vertex API using the Gemini Pro model and extracts the text
|
||||
/// from the response.
|
||||
pub async fn prompt_text(
|
||||
&self,
|
||||
prompt: &str,
|
||||
generation_config: Option<&GenerationConfig>,
|
||||
) -> Result<String> {
|
||||
let request = GenerateContentRequest {
|
||||
contents: vec![Content {
|
||||
role: "user".to_string(),
|
||||
parts: vec![Part::Text(prompt.to_string())],
|
||||
}],
|
||||
generation_config: generation_config.cloned(),
|
||||
tools: None,
|
||||
};
|
||||
|
||||
let response = self
|
||||
.stream_generate_content(&request, Model::GeminiPro)
|
||||
.await?;
|
||||
|
||||
let text = response
|
||||
.0
|
||||
.into_iter()
|
||||
.flat_map(|chunk| {
|
||||
chunk.candidates.into_iter().flat_map(|candidate| {
|
||||
candidate
|
||||
.content
|
||||
.parts
|
||||
.into_iter()
|
||||
.map(|part| match part {
|
||||
Part::Text(text) => Some(text),
|
||||
_ => None,
|
||||
})
|
||||
.filter(Option::is_some)
|
||||
.flatten()
|
||||
})
|
||||
})
|
||||
.collect::<String>();
|
||||
Ok(text)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user