Refactor types into separate files

This commit is contained in:
2024-02-23 18:25:32 +00:00
parent 93285e53dd
commit b8f2b7e85e
7 changed files with 219 additions and 215 deletions

View File

@@ -20,9 +20,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let prompt = "Tell me the story of the genesis of the universe as a bedtime story."; let prompt = "Tell me the story of the genesis of the universe as a bedtime story.";
let request = GenerateContentRequest::from_prompt(prompt, None); let request = GenerateContentRequest::from_prompt(prompt, None);
let queue = gemini let queue = gemini.stream_generate_content(&request, "gemini-pro").await;
.stream_generate_content(&request, Model::GeminiPro)
.await;
while let Some(response) = queue.pop().await { while let Some(response) = queue.pop().await {
if let GenerateContentResponse::Ok { if let GenerateContentResponse::Ok {

View File

@@ -13,18 +13,6 @@ use crate::{prelude::Part, token_provider::TokenProvider};
pub static AUTH_SCOPE: &[&str] = &["https://www.googleapis.com/auth/cloud-platform"]; pub static AUTH_SCOPE: &[&str] = &["https://www.googleapis.com/auth/cloud-platform"];
pub enum Model {
GeminiPro,
}
impl ToString for Model {
fn to_string(&self) -> String {
match self {
Model::GeminiPro => "gemini-pro".to_string(),
}
}
}
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct GeminiClient<T: TokenProvider + Clone> { pub struct GeminiClient<T: TokenProvider + Clone> {
token_provider: T, token_provider: T,
@@ -56,7 +44,7 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
pub async fn stream_generate_content( pub async fn stream_generate_content(
&self, &self,
request: &GenerateContentRequest, request: &GenerateContentRequest,
model: Model, model: &str,
) -> Arc<Queue<Option<GenerateContentResponse>>> { ) -> Arc<Queue<Option<GenerateContentResponse>>> {
let queue = Arc::new(Queue::<Option<GenerateContentResponse>>::new()); let queue = Arc::new(Queue::<Option<GenerateContentResponse>>::new());
@@ -93,7 +81,7 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
pub async fn generate_content( pub async fn generate_content(
&self, &self,
request: &GenerateContentRequest, request: &GenerateContentRequest,
model: Model, model: &str,
) -> Result<GenerateContentResponse> { ) -> Result<GenerateContentResponse> {
let access_token = self.token_provider.get_token(AUTH_SCOPE).await?; let access_token = self.token_provider.get_token(AUTH_SCOPE).await?;
let endpoint_url: String = format!( let endpoint_url: String = format!(
@@ -131,7 +119,7 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
tools: None, tools: None,
}; };
let response = self.generate_content(&request, Model::GeminiPro).await?; let response = self.generate_content(&request, "gemini-pro").await?;
// Check for errors in the response. // Check for errors in the response.
let mut candidates = GeminiClient::<T>::collect_text_from_response(&response)?; let mut candidates = GeminiClient::<T>::collect_text_from_response(&response)?;
@@ -158,7 +146,7 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
tools: None, tools: None,
}; };
let response = self.generate_content(&request, Model::GeminiPro).await?; let response = self.generate_content(&request, "gemini-pro").await?;
let mut candidates = GeminiClient::<T>::collect_text_from_response(&response)?; let mut candidates = GeminiClient::<T>::collect_text_from_response(&response)?;
match candidates.pop() { match candidates.pop() {

41
src/types/common.rs Normal file
View File

@@ -0,0 +1,41 @@
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Content {
pub role: String,
pub parts: Option<Vec<Part>>,
}
impl Content {
pub fn get_text(&self) -> Option<String> {
self.parts.as_ref().map(|parts| {
parts
.iter()
.filter_map(|part| match part {
Part::Text(text) => Some(text.clone()),
_ => None,
})
.collect::<String>()
})
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum Part {
Text(String),
InlineData {
mime_type: String,
data: String,
},
FileData {
mime_type: String,
file_uri: String,
},
FunctionCall {
name: String,
args: HashMap<String, String>,
},
}

14
src/types/count_tokens.rs Normal file
View File

@@ -0,0 +1,14 @@
use serde::{Deserialize, Serialize};
use super::Content;
#[derive(Debug, Serialize, Deserialize)]
pub struct CountTokensRequest {
pub contents: Content,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CountTokensResponse {
pub total_tokens: i32,
}

31
src/types/error.rs Normal file
View File

@@ -0,0 +1,31 @@
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Error {
pub code: i32,
pub message: String,
pub status: String,
pub details: Vec<ErrorType>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Link {
pub description: String,
pub url: String,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "@type")]
pub enum ErrorType {
#[serde(rename = "type.googleapis.com/google.rpc.ErrorInfo")]
ErrorInfo { metadata: ErrorInfoMetadata },
#[serde(rename = "type.googleapis.com/google.rpc.Help")]
Help { links: Vec<Link> },
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ErrorInfoMetadata {
service: String,
consumer: String,
}

View File

@@ -1,196 +1,119 @@
use std::collections::HashMap; use std::collections::HashMap;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize)] use super::{Content, Error, Part};
pub struct CountTokensRequest {
pub contents: Content, #[derive(Clone, Serialize, Deserialize)]
} pub struct GenerateContentRequest {
pub contents: Vec<Content>,
#[derive(Debug, Serialize, Deserialize)] pub generation_config: Option<GenerationConfig>,
#[serde(rename_all = "camelCase")] pub tools: Option<Vec<Tools>>,
pub struct CountTokensResponse { }
pub total_tokens: i32,
} impl GenerateContentRequest {
pub fn from_prompt(prompt: &str, generation_config: Option<GenerationConfig>) -> Self {
#[derive(Clone, Serialize, Deserialize)] GenerateContentRequest {
pub struct GenerateContentRequest { contents: vec![Content {
pub contents: Vec<Content>, role: "user".to_string(),
pub generation_config: Option<GenerationConfig>, parts: Some(vec![Part::Text(prompt.to_string())]),
pub tools: Option<Vec<Tools>>, }],
} generation_config,
tools: None,
impl GenerateContentRequest { }
pub fn from_prompt(prompt: &str, generation_config: Option<GenerationConfig>) -> Self { }
GenerateContentRequest { }
contents: vec![Content {
role: "user".to_string(), #[derive(Clone, Serialize, Deserialize)]
parts: Some(vec![Part::Text(prompt.to_string())]), pub struct Tools {
}], pub function_declarations: Option<Vec<FunctionDeclaration>>,
generation_config, }
tools: None,
} #[derive(Clone, Debug, Serialize, Deserialize, Default)]
} #[serde(rename_all = "camelCase")]
} pub struct GenerationConfig {
pub max_output_tokens: Option<i32>,
#[derive(Clone, Serialize, Deserialize)] pub temperature: Option<f32>,
pub struct Tools { pub top_p: Option<f32>,
pub function_declarations: Option<Vec<FunctionDeclaration>>, pub top_k: Option<i32>,
} pub stop_sequences: Option<Vec<String>>,
pub candidate_count: Option<u32>,
#[derive(Clone, Debug, Serialize, Deserialize)] }
pub struct Content {
pub role: String, #[derive(Debug, Serialize, Deserialize)]
pub parts: Option<Vec<Part>>, #[serde(rename_all = "camelCase")]
} pub struct Candidate {
pub content: Content,
impl Content { pub citation_metadata: Option<CitationMetadata>,
pub fn get_text(&self) -> Option<String> { pub safety_ratings: Vec<SafetyRating>,
self.parts.as_ref().map(|parts| { pub finish_reason: Option<String>,
parts }
.iter()
.filter_map(|part| match part { impl Candidate {
Part::Text(text) => Some(text.clone()), pub fn get_text(&self) -> Option<String> {
_ => None, self.content.get_text()
}) }
.collect::<String>() }
})
} #[derive(Debug, Serialize, Deserialize)]
} #[serde(rename_all = "camelCase")]
pub struct Citation {
#[derive(Clone, Debug, Serialize, Deserialize, Default)] pub start_index: i32,
#[serde(rename_all = "camelCase")] pub end_index: i32,
pub struct GenerationConfig { pub uri: Option<String>,
pub max_output_tokens: Option<i32>, }
pub temperature: Option<f32>,
pub top_p: Option<f32>, #[derive(Debug, Serialize, Deserialize)]
pub top_k: Option<i32>, pub struct CitationMetadata {
pub stop_sequences: Option<Vec<String>>, pub citations: Vec<Citation>,
pub candidate_count: Option<u32>, }
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, Serialize, Deserialize)] pub struct SafetyRating {
#[serde(rename_all = "camelCase")] pub category: String,
pub enum Part { pub probability: String,
Text(String), }
InlineData {
mime_type: String, #[derive(Debug, Serialize, Deserialize)]
data: String, #[serde(rename_all = "camelCase")]
}, pub struct UsageMetadata {
FileData { pub candidates_token_count: Option<i32>,
mime_type: String, pub prompt_token_count: i32,
file_uri: String, pub total_token_count: i32,
}, }
FunctionCall {
name: String, #[derive(Clone, Debug, Serialize, Deserialize)]
args: HashMap<String, String>, #[serde(rename_all = "camelCase")]
}, pub struct FunctionDeclaration {
} pub name: String,
pub description: String,
#[derive(Debug, Serialize, Deserialize)] pub parameters: FunctionParameters,
#[serde(rename_all = "camelCase")] }
#[serde(untagged)]
pub enum GenerateContentResponse { #[derive(Clone, Debug, Serialize, Deserialize)]
Ok { #[serde(rename_all = "camelCase")]
candidates: Vec<Candidate>, pub struct FunctionParameters {
usage_metadata: Option<UsageMetadata>, pub r#type: String,
}, pub properties: HashMap<String, FunctionParametersProperty>,
Error { pub required: Vec<String>,
error: Error, }
},
} #[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
#[derive(Debug, Serialize, Deserialize)] pub struct FunctionParametersProperty {
#[serde(rename_all = "camelCase")] pub r#type: String,
pub struct Candidate { pub description: String,
pub content: Content, }
pub citation_metadata: Option<CitationMetadata>,
pub safety_ratings: Vec<SafetyRating>, #[derive(Debug, Serialize, Deserialize)]
pub finish_reason: Option<String>, #[serde(rename_all = "camelCase")]
} #[serde(untagged)]
pub enum GenerateContentResponse {
impl Candidate { Ok {
pub fn get_text(&self) -> Option<String> { candidates: Vec<Candidate>,
self.content.get_text() usage_metadata: Option<UsageMetadata>,
} },
} Error {
error: Error,
#[derive(Debug, Serialize, Deserialize)] },
pub struct SafetyRating { }
pub category: String,
pub probability: String,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Citation {
pub start_index: i32,
pub end_index: i32,
pub uri: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct CitationMetadata {
pub citations: Vec<Citation>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct UsageMetadata {
pub candidates_token_count: Option<i32>,
pub prompt_token_count: i32,
pub total_token_count: i32,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionDeclaration {
pub name: String,
pub description: String,
pub parameters: FunctionParameters,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionParameters {
pub r#type: String,
pub properties: HashMap<String, FunctionParametersProperty>,
pub required: Vec<String>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionParametersProperty {
pub r#type: String,
pub description: String,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Error {
pub code: i32,
pub message: String,
pub status: String,
pub details: Vec<ErrorType>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Link {
pub description: String,
pub url: String,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "@type")]
pub enum ErrorType {
#[serde(rename = "type.googleapis.com/google.rpc.ErrorInfo")]
ErrorInfo { metadata: ErrorInfoMetadata },
#[serde(rename = "type.googleapis.com/google.rpc.Help")]
Help { links: Vec<Link> },
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ErrorInfoMetadata {
service: String,
consumer: String,
}

9
src/types/mod.rs Normal file
View File

@@ -0,0 +1,9 @@
mod common;
mod count_tokens;
mod error;
mod generate_content;
pub use common::*;
pub use count_tokens::*;
pub use error::*;
pub use generate_content::*;