Refactor types into separate files

This commit is contained in:
2024-02-23 18:25:32 +00:00
parent 93285e53dd
commit b8f2b7e85e
7 changed files with 219 additions and 215 deletions

View File

@@ -20,9 +20,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let prompt = "Tell me the story of the genesis of the universe as a bedtime story."; let prompt = "Tell me the story of the genesis of the universe as a bedtime story.";
let request = GenerateContentRequest::from_prompt(prompt, None); let request = GenerateContentRequest::from_prompt(prompt, None);
let queue = gemini let queue = gemini.stream_generate_content(&request, "gemini-pro").await;
.stream_generate_content(&request, Model::GeminiPro)
.await;
while let Some(response) = queue.pop().await { while let Some(response) = queue.pop().await {
if let GenerateContentResponse::Ok { if let GenerateContentResponse::Ok {

View File

@@ -13,18 +13,6 @@ use crate::{prelude::Part, token_provider::TokenProvider};
pub static AUTH_SCOPE: &[&str] = &["https://www.googleapis.com/auth/cloud-platform"]; pub static AUTH_SCOPE: &[&str] = &["https://www.googleapis.com/auth/cloud-platform"];
pub enum Model {
GeminiPro,
}
impl ToString for Model {
fn to_string(&self) -> String {
match self {
Model::GeminiPro => "gemini-pro".to_string(),
}
}
}
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct GeminiClient<T: TokenProvider + Clone> { pub struct GeminiClient<T: TokenProvider + Clone> {
token_provider: T, token_provider: T,
@@ -56,7 +44,7 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
pub async fn stream_generate_content( pub async fn stream_generate_content(
&self, &self,
request: &GenerateContentRequest, request: &GenerateContentRequest,
model: Model, model: &str,
) -> Arc<Queue<Option<GenerateContentResponse>>> { ) -> Arc<Queue<Option<GenerateContentResponse>>> {
let queue = Arc::new(Queue::<Option<GenerateContentResponse>>::new()); let queue = Arc::new(Queue::<Option<GenerateContentResponse>>::new());
@@ -93,7 +81,7 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
pub async fn generate_content( pub async fn generate_content(
&self, &self,
request: &GenerateContentRequest, request: &GenerateContentRequest,
model: Model, model: &str,
) -> Result<GenerateContentResponse> { ) -> Result<GenerateContentResponse> {
let access_token = self.token_provider.get_token(AUTH_SCOPE).await?; let access_token = self.token_provider.get_token(AUTH_SCOPE).await?;
let endpoint_url: String = format!( let endpoint_url: String = format!(
@@ -131,7 +119,7 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
tools: None, tools: None,
}; };
let response = self.generate_content(&request, Model::GeminiPro).await?; let response = self.generate_content(&request, "gemini-pro").await?;
// Check for errors in the response. // Check for errors in the response.
let mut candidates = GeminiClient::<T>::collect_text_from_response(&response)?; let mut candidates = GeminiClient::<T>::collect_text_from_response(&response)?;
@@ -158,7 +146,7 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
tools: None, tools: None,
}; };
let response = self.generate_content(&request, Model::GeminiPro).await?; let response = self.generate_content(&request, "gemini-pro").await?;
let mut candidates = GeminiClient::<T>::collect_text_from_response(&response)?; let mut candidates = GeminiClient::<T>::collect_text_from_response(&response)?;
match candidates.pop() { match candidates.pop() {

41
src/types/common.rs Normal file
View File

@@ -0,0 +1,41 @@
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Content {
pub role: String,
pub parts: Option<Vec<Part>>,
}
impl Content {
pub fn get_text(&self) -> Option<String> {
self.parts.as_ref().map(|parts| {
parts
.iter()
.filter_map(|part| match part {
Part::Text(text) => Some(text.clone()),
_ => None,
})
.collect::<String>()
})
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum Part {
Text(String),
InlineData {
mime_type: String,
data: String,
},
FileData {
mime_type: String,
file_uri: String,
},
FunctionCall {
name: String,
args: HashMap<String, String>,
},
}

14
src/types/count_tokens.rs Normal file
View File

@@ -0,0 +1,14 @@
use serde::{Deserialize, Serialize};
use super::Content;
#[derive(Debug, Serialize, Deserialize)]
pub struct CountTokensRequest {
pub contents: Content,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CountTokensResponse {
pub total_tokens: i32,
}

31
src/types/error.rs Normal file
View File

@@ -0,0 +1,31 @@
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Error {
pub code: i32,
pub message: String,
pub status: String,
pub details: Vec<ErrorType>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Link {
pub description: String,
pub url: String,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "@type")]
pub enum ErrorType {
#[serde(rename = "type.googleapis.com/google.rpc.ErrorInfo")]
ErrorInfo { metadata: ErrorInfoMetadata },
#[serde(rename = "type.googleapis.com/google.rpc.Help")]
Help { links: Vec<Link> },
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ErrorInfoMetadata {
service: String,
consumer: String,
}

View File

@@ -2,16 +2,7 @@ use std::collections::HashMap;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize)] use super::{Content, Error, Part};
pub struct CountTokensRequest {
pub contents: Content,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CountTokensResponse {
pub total_tokens: i32,
}
#[derive(Clone, Serialize, Deserialize)] #[derive(Clone, Serialize, Deserialize)]
pub struct GenerateContentRequest { pub struct GenerateContentRequest {
@@ -38,26 +29,6 @@ pub struct Tools {
pub function_declarations: Option<Vec<FunctionDeclaration>>, pub function_declarations: Option<Vec<FunctionDeclaration>>,
} }
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Content {
pub role: String,
pub parts: Option<Vec<Part>>,
}
impl Content {
pub fn get_text(&self) -> Option<String> {
self.parts.as_ref().map(|parts| {
parts
.iter()
.filter_map(|part| match part {
Part::Text(text) => Some(text.clone()),
_ => None,
})
.collect::<String>()
})
}
}
#[derive(Clone, Debug, Serialize, Deserialize, Default)] #[derive(Clone, Debug, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct GenerationConfig { pub struct GenerationConfig {
@@ -69,37 +40,6 @@ pub struct GenerationConfig {
pub candidate_count: Option<u32>, pub candidate_count: Option<u32>,
} }
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum Part {
Text(String),
InlineData {
mime_type: String,
data: String,
},
FileData {
mime_type: String,
file_uri: String,
},
FunctionCall {
name: String,
args: HashMap<String, String>,
},
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
#[serde(untagged)]
pub enum GenerateContentResponse {
Ok {
candidates: Vec<Candidate>,
usage_metadata: Option<UsageMetadata>,
},
Error {
error: Error,
},
}
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct Candidate { pub struct Candidate {
@@ -115,12 +55,6 @@ impl Candidate {
} }
} }
#[derive(Debug, Serialize, Deserialize)]
pub struct SafetyRating {
pub category: String,
pub probability: String,
}
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct Citation { pub struct Citation {
@@ -134,6 +68,12 @@ pub struct CitationMetadata {
pub citations: Vec<Citation>, pub citations: Vec<Citation>,
} }
#[derive(Debug, Serialize, Deserialize)]
pub struct SafetyRating {
pub category: String,
pub probability: String,
}
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct UsageMetadata { pub struct UsageMetadata {
@@ -165,32 +105,15 @@ pub struct FunctionParametersProperty {
pub description: String, pub description: String,
} }
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct Error { #[serde(rename_all = "camelCase")]
pub code: i32, #[serde(untagged)]
pub message: String, pub enum GenerateContentResponse {
pub status: String, Ok {
pub details: Vec<ErrorType>, candidates: Vec<Candidate>,
} usage_metadata: Option<UsageMetadata>,
},
#[derive(Clone, Debug, Serialize, Deserialize)] Error {
pub struct Link { error: Error,
pub description: String, },
pub url: String,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "@type")]
pub enum ErrorType {
#[serde(rename = "type.googleapis.com/google.rpc.ErrorInfo")]
ErrorInfo { metadata: ErrorInfoMetadata },
#[serde(rename = "type.googleapis.com/google.rpc.Help")]
Help { links: Vec<Link> },
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ErrorInfoMetadata {
service: String,
consumer: String,
} }

9
src/types/mod.rs Normal file
View File

@@ -0,0 +1,9 @@
mod common;
mod count_tokens;
mod error;
mod generate_content;
pub use common::*;
pub use count_tokens::*;
pub use error::*;
pub use generate_content::*;