From b8f2b7e85e77b0785e9b3e7f19f10b0238465123 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Cipriani=20Bandarra?= Date: Fri, 23 Feb 2024 18:25:32 +0000 Subject: [PATCH] Refactor types into separate files --- examples/text-from-text-streaming.rs | 4 +- src/client.rs | 20 +- src/types/common.rs | 41 +++ src/types/count_tokens.rs | 14 + src/types/error.rs | 31 ++ src/{types.rs => types/generate_content.rs} | 315 ++++++++------------ src/types/mod.rs | 9 + 7 files changed, 219 insertions(+), 215 deletions(-) create mode 100644 src/types/common.rs create mode 100644 src/types/count_tokens.rs create mode 100644 src/types/error.rs rename src/{types.rs => types/generate_content.rs} (60%) create mode 100644 src/types/mod.rs diff --git a/examples/text-from-text-streaming.rs b/examples/text-from-text-streaming.rs index ddd1b0f..b241bbf 100644 --- a/examples/text-from-text-streaming.rs +++ b/examples/text-from-text-streaming.rs @@ -20,9 +20,7 @@ async fn main() -> Result<(), Box> { let prompt = "Tell me the story of the genesis of the universe as a bedtime story."; let request = GenerateContentRequest::from_prompt(prompt, None); - let queue = gemini - .stream_generate_content(&request, Model::GeminiPro) - .await; + let queue = gemini.stream_generate_content(&request, "gemini-pro").await; while let Some(response) = queue.pop().await { if let GenerateContentResponse::Ok { diff --git a/src/client.rs b/src/client.rs index 6f1037d..2ea3782 100644 --- a/src/client.rs +++ b/src/client.rs @@ -13,18 +13,6 @@ use crate::{prelude::Part, token_provider::TokenProvider}; pub static AUTH_SCOPE: &[&str] = &["https://www.googleapis.com/auth/cloud-platform"]; -pub enum Model { - GeminiPro, -} - -impl ToString for Model { - fn to_string(&self) -> String { - match self { - Model::GeminiPro => "gemini-pro".to_string(), - } - } -} - #[derive(Clone, Debug)] pub struct GeminiClient { token_provider: T, @@ -56,7 +44,7 @@ impl GeminiClient { pub async fn stream_generate_content( &self, request: &GenerateContentRequest, - model: Model, + model: &str, ) -> Arc>> { let queue = Arc::new(Queue::>::new()); @@ -93,7 +81,7 @@ impl GeminiClient { pub async fn generate_content( &self, request: &GenerateContentRequest, - model: Model, + model: &str, ) -> Result { let access_token = self.token_provider.get_token(AUTH_SCOPE).await?; let endpoint_url: String = format!( @@ -131,7 +119,7 @@ impl GeminiClient { tools: None, }; - let response = self.generate_content(&request, Model::GeminiPro).await?; + let response = self.generate_content(&request, "gemini-pro").await?; // Check for errors in the response. let mut candidates = GeminiClient::::collect_text_from_response(&response)?; @@ -158,7 +146,7 @@ impl GeminiClient { tools: None, }; - let response = self.generate_content(&request, Model::GeminiPro).await?; + let response = self.generate_content(&request, "gemini-pro").await?; let mut candidates = GeminiClient::::collect_text_from_response(&response)?; match candidates.pop() { diff --git a/src/types/common.rs b/src/types/common.rs new file mode 100644 index 0000000..1d6a43b --- /dev/null +++ b/src/types/common.rs @@ -0,0 +1,41 @@ +use std::collections::HashMap; + +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Content { + pub role: String, + pub parts: Option>, +} + +impl Content { + pub fn get_text(&self) -> Option { + self.parts.as_ref().map(|parts| { + parts + .iter() + .filter_map(|part| match part { + Part::Text(text) => Some(text.clone()), + _ => None, + }) + .collect::() + }) + } +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub enum Part { + Text(String), + InlineData { + mime_type: String, + data: String, + }, + FileData { + mime_type: String, + file_uri: String, + }, + FunctionCall { + name: String, + args: HashMap, + }, +} diff --git a/src/types/count_tokens.rs b/src/types/count_tokens.rs new file mode 100644 index 0000000..f3036e5 --- /dev/null +++ b/src/types/count_tokens.rs @@ -0,0 +1,14 @@ +use serde::{Deserialize, Serialize}; + +use super::Content; + +#[derive(Debug, Serialize, Deserialize)] +pub struct CountTokensRequest { + pub contents: Content, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CountTokensResponse { + pub total_tokens: i32, +} diff --git a/src/types/error.rs b/src/types/error.rs new file mode 100644 index 0000000..4225230 --- /dev/null +++ b/src/types/error.rs @@ -0,0 +1,31 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Error { + pub code: i32, + pub message: String, + pub status: String, + pub details: Vec, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Link { + pub description: String, + pub url: String, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(tag = "@type")] +pub enum ErrorType { + #[serde(rename = "type.googleapis.com/google.rpc.ErrorInfo")] + ErrorInfo { metadata: ErrorInfoMetadata }, + + #[serde(rename = "type.googleapis.com/google.rpc.Help")] + Help { links: Vec }, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ErrorInfoMetadata { + service: String, + consumer: String, +} diff --git a/src/types.rs b/src/types/generate_content.rs similarity index 60% rename from src/types.rs rename to src/types/generate_content.rs index f40dd7b..bbd3f8a 100644 --- a/src/types.rs +++ b/src/types/generate_content.rs @@ -1,196 +1,119 @@ -use std::collections::HashMap; - -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Serialize, Deserialize)] -pub struct CountTokensRequest { - pub contents: Content, -} - -#[derive(Debug, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct CountTokensResponse { - pub total_tokens: i32, -} - -#[derive(Clone, Serialize, Deserialize)] -pub struct GenerateContentRequest { - pub contents: Vec, - pub generation_config: Option, - pub tools: Option>, -} - -impl GenerateContentRequest { - pub fn from_prompt(prompt: &str, generation_config: Option) -> Self { - GenerateContentRequest { - contents: vec![Content { - role: "user".to_string(), - parts: Some(vec![Part::Text(prompt.to_string())]), - }], - generation_config, - tools: None, - } - } -} - -#[derive(Clone, Serialize, Deserialize)] -pub struct Tools { - pub function_declarations: Option>, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct Content { - pub role: String, - pub parts: Option>, -} - -impl Content { - pub fn get_text(&self) -> Option { - self.parts.as_ref().map(|parts| { - parts - .iter() - .filter_map(|part| match part { - Part::Text(text) => Some(text.clone()), - _ => None, - }) - .collect::() - }) - } -} - -#[derive(Clone, Debug, Serialize, Deserialize, Default)] -#[serde(rename_all = "camelCase")] -pub struct GenerationConfig { - pub max_output_tokens: Option, - pub temperature: Option, - pub top_p: Option, - pub top_k: Option, - pub stop_sequences: Option>, - pub candidate_count: Option, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub enum Part { - Text(String), - InlineData { - mime_type: String, - data: String, - }, - FileData { - mime_type: String, - file_uri: String, - }, - FunctionCall { - name: String, - args: HashMap, - }, -} - -#[derive(Debug, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -#[serde(untagged)] -pub enum GenerateContentResponse { - Ok { - candidates: Vec, - usage_metadata: Option, - }, - Error { - error: Error, - }, -} - -#[derive(Debug, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct Candidate { - pub content: Content, - pub citation_metadata: Option, - pub safety_ratings: Vec, - pub finish_reason: Option, -} - -impl Candidate { - pub fn get_text(&self) -> Option { - self.content.get_text() - } -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct SafetyRating { - pub category: String, - pub probability: String, -} - -#[derive(Debug, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct Citation { - pub start_index: i32, - pub end_index: i32, - pub uri: Option, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct CitationMetadata { - pub citations: Vec, -} - -#[derive(Debug, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct UsageMetadata { - pub candidates_token_count: Option, - pub prompt_token_count: i32, - pub total_token_count: i32, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct FunctionDeclaration { - pub name: String, - pub description: String, - pub parameters: FunctionParameters, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct FunctionParameters { - pub r#type: String, - pub properties: HashMap, - pub required: Vec, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct FunctionParametersProperty { - pub r#type: String, - pub description: String, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct Error { - pub code: i32, - pub message: String, - pub status: String, - pub details: Vec, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct Link { - pub description: String, - pub url: String, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -#[serde(tag = "@type")] -pub enum ErrorType { - #[serde(rename = "type.googleapis.com/google.rpc.ErrorInfo")] - ErrorInfo { metadata: ErrorInfoMetadata }, - - #[serde(rename = "type.googleapis.com/google.rpc.Help")] - Help { links: Vec }, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct ErrorInfoMetadata { - service: String, - consumer: String, -} +use std::collections::HashMap; + +use serde::{Deserialize, Serialize}; + +use super::{Content, Error, Part}; + +#[derive(Clone, Serialize, Deserialize)] +pub struct GenerateContentRequest { + pub contents: Vec, + pub generation_config: Option, + pub tools: Option>, +} + +impl GenerateContentRequest { + pub fn from_prompt(prompt: &str, generation_config: Option) -> Self { + GenerateContentRequest { + contents: vec![Content { + role: "user".to_string(), + parts: Some(vec![Part::Text(prompt.to_string())]), + }], + generation_config, + tools: None, + } + } +} + +#[derive(Clone, Serialize, Deserialize)] +pub struct Tools { + pub function_declarations: Option>, +} + +#[derive(Clone, Debug, Serialize, Deserialize, Default)] +#[serde(rename_all = "camelCase")] +pub struct GenerationConfig { + pub max_output_tokens: Option, + pub temperature: Option, + pub top_p: Option, + pub top_k: Option, + pub stop_sequences: Option>, + pub candidate_count: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Candidate { + pub content: Content, + pub citation_metadata: Option, + pub safety_ratings: Vec, + pub finish_reason: Option, +} + +impl Candidate { + pub fn get_text(&self) -> Option { + self.content.get_text() + } +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Citation { + pub start_index: i32, + pub end_index: i32, + pub uri: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct CitationMetadata { + pub citations: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct SafetyRating { + pub category: String, + pub probability: String, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct UsageMetadata { + pub candidates_token_count: Option, + pub prompt_token_count: i32, + pub total_token_count: i32, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct FunctionDeclaration { + pub name: String, + pub description: String, + pub parameters: FunctionParameters, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct FunctionParameters { + pub r#type: String, + pub properties: HashMap, + pub required: Vec, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct FunctionParametersProperty { + pub r#type: String, + pub description: String, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +#[serde(untagged)] +pub enum GenerateContentResponse { + Ok { + candidates: Vec, + usage_metadata: Option, + }, + Error { + error: Error, + }, +} diff --git a/src/types/mod.rs b/src/types/mod.rs new file mode 100644 index 0000000..406384c --- /dev/null +++ b/src/types/mod.rs @@ -0,0 +1,9 @@ +mod common; +mod count_tokens; +mod error; +mod generate_content; + +pub use common::*; +pub use count_tokens::*; +pub use error::*; +pub use generate_content::*;