Refactor types into separate files

This commit is contained in:
2024-02-23 18:25:32 +00:00
parent 93285e53dd
commit b8f2b7e85e
7 changed files with 219 additions and 215 deletions

View File

@@ -13,18 +13,6 @@ use crate::{prelude::Part, token_provider::TokenProvider};
pub static AUTH_SCOPE: &[&str] = &["https://www.googleapis.com/auth/cloud-platform"];
pub enum Model {
GeminiPro,
}
impl ToString for Model {
fn to_string(&self) -> String {
match self {
Model::GeminiPro => "gemini-pro".to_string(),
}
}
}
#[derive(Clone, Debug)]
pub struct GeminiClient<T: TokenProvider + Clone> {
token_provider: T,
@@ -56,7 +44,7 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
pub async fn stream_generate_content(
&self,
request: &GenerateContentRequest,
model: Model,
model: &str,
) -> Arc<Queue<Option<GenerateContentResponse>>> {
let queue = Arc::new(Queue::<Option<GenerateContentResponse>>::new());
@@ -93,7 +81,7 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
pub async fn generate_content(
&self,
request: &GenerateContentRequest,
model: Model,
model: &str,
) -> Result<GenerateContentResponse> {
let access_token = self.token_provider.get_token(AUTH_SCOPE).await?;
let endpoint_url: String = format!(
@@ -131,7 +119,7 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
tools: None,
};
let response = self.generate_content(&request, Model::GeminiPro).await?;
let response = self.generate_content(&request, "gemini-pro").await?;
// Check for errors in the response.
let mut candidates = GeminiClient::<T>::collect_text_from_response(&response)?;
@@ -158,7 +146,7 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
tools: None,
};
let response = self.generate_content(&request, Model::GeminiPro).await?;
let response = self.generate_content(&request, "gemini-pro").await?;
let mut candidates = GeminiClient::<T>::collect_text_from_response(&response)?;
match candidates.pop() {

41
src/types/common.rs Normal file
View File

@@ -0,0 +1,41 @@
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Content {
pub role: String,
pub parts: Option<Vec<Part>>,
}
impl Content {
pub fn get_text(&self) -> Option<String> {
self.parts.as_ref().map(|parts| {
parts
.iter()
.filter_map(|part| match part {
Part::Text(text) => Some(text.clone()),
_ => None,
})
.collect::<String>()
})
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum Part {
Text(String),
InlineData {
mime_type: String,
data: String,
},
FileData {
mime_type: String,
file_uri: String,
},
FunctionCall {
name: String,
args: HashMap<String, String>,
},
}

14
src/types/count_tokens.rs Normal file
View File

@@ -0,0 +1,14 @@
use serde::{Deserialize, Serialize};
use super::Content;
#[derive(Debug, Serialize, Deserialize)]
pub struct CountTokensRequest {
pub contents: Content,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CountTokensResponse {
pub total_tokens: i32,
}

31
src/types/error.rs Normal file
View File

@@ -0,0 +1,31 @@
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Error {
pub code: i32,
pub message: String,
pub status: String,
pub details: Vec<ErrorType>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Link {
pub description: String,
pub url: String,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "@type")]
pub enum ErrorType {
#[serde(rename = "type.googleapis.com/google.rpc.ErrorInfo")]
ErrorInfo { metadata: ErrorInfoMetadata },
#[serde(rename = "type.googleapis.com/google.rpc.Help")]
Help { links: Vec<Link> },
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ErrorInfoMetadata {
service: String,
consumer: String,
}

View File

@@ -1,196 +1,119 @@
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize)]
pub struct CountTokensRequest {
pub contents: Content,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CountTokensResponse {
pub total_tokens: i32,
}
#[derive(Clone, Serialize, Deserialize)]
pub struct GenerateContentRequest {
pub contents: Vec<Content>,
pub generation_config: Option<GenerationConfig>,
pub tools: Option<Vec<Tools>>,
}
impl GenerateContentRequest {
pub fn from_prompt(prompt: &str, generation_config: Option<GenerationConfig>) -> Self {
GenerateContentRequest {
contents: vec![Content {
role: "user".to_string(),
parts: Some(vec![Part::Text(prompt.to_string())]),
}],
generation_config,
tools: None,
}
}
}
#[derive(Clone, Serialize, Deserialize)]
pub struct Tools {
pub function_declarations: Option<Vec<FunctionDeclaration>>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Content {
pub role: String,
pub parts: Option<Vec<Part>>,
}
impl Content {
pub fn get_text(&self) -> Option<String> {
self.parts.as_ref().map(|parts| {
parts
.iter()
.filter_map(|part| match part {
Part::Text(text) => Some(text.clone()),
_ => None,
})
.collect::<String>()
})
}
}
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct GenerationConfig {
pub max_output_tokens: Option<i32>,
pub temperature: Option<f32>,
pub top_p: Option<f32>,
pub top_k: Option<i32>,
pub stop_sequences: Option<Vec<String>>,
pub candidate_count: Option<u32>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum Part {
Text(String),
InlineData {
mime_type: String,
data: String,
},
FileData {
mime_type: String,
file_uri: String,
},
FunctionCall {
name: String,
args: HashMap<String, String>,
},
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
#[serde(untagged)]
pub enum GenerateContentResponse {
Ok {
candidates: Vec<Candidate>,
usage_metadata: Option<UsageMetadata>,
},
Error {
error: Error,
},
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Candidate {
pub content: Content,
pub citation_metadata: Option<CitationMetadata>,
pub safety_ratings: Vec<SafetyRating>,
pub finish_reason: Option<String>,
}
impl Candidate {
pub fn get_text(&self) -> Option<String> {
self.content.get_text()
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct SafetyRating {
pub category: String,
pub probability: String,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Citation {
pub start_index: i32,
pub end_index: i32,
pub uri: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct CitationMetadata {
pub citations: Vec<Citation>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct UsageMetadata {
pub candidates_token_count: Option<i32>,
pub prompt_token_count: i32,
pub total_token_count: i32,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionDeclaration {
pub name: String,
pub description: String,
pub parameters: FunctionParameters,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionParameters {
pub r#type: String,
pub properties: HashMap<String, FunctionParametersProperty>,
pub required: Vec<String>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionParametersProperty {
pub r#type: String,
pub description: String,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Error {
pub code: i32,
pub message: String,
pub status: String,
pub details: Vec<ErrorType>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Link {
pub description: String,
pub url: String,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "@type")]
pub enum ErrorType {
#[serde(rename = "type.googleapis.com/google.rpc.ErrorInfo")]
ErrorInfo { metadata: ErrorInfoMetadata },
#[serde(rename = "type.googleapis.com/google.rpc.Help")]
Help { links: Vec<Link> },
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ErrorInfoMetadata {
service: String,
consumer: String,
}
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use super::{Content, Error, Part};
#[derive(Clone, Serialize, Deserialize)]
pub struct GenerateContentRequest {
pub contents: Vec<Content>,
pub generation_config: Option<GenerationConfig>,
pub tools: Option<Vec<Tools>>,
}
impl GenerateContentRequest {
pub fn from_prompt(prompt: &str, generation_config: Option<GenerationConfig>) -> Self {
GenerateContentRequest {
contents: vec![Content {
role: "user".to_string(),
parts: Some(vec![Part::Text(prompt.to_string())]),
}],
generation_config,
tools: None,
}
}
}
#[derive(Clone, Serialize, Deserialize)]
pub struct Tools {
pub function_declarations: Option<Vec<FunctionDeclaration>>,
}
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct GenerationConfig {
pub max_output_tokens: Option<i32>,
pub temperature: Option<f32>,
pub top_p: Option<f32>,
pub top_k: Option<i32>,
pub stop_sequences: Option<Vec<String>>,
pub candidate_count: Option<u32>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Candidate {
pub content: Content,
pub citation_metadata: Option<CitationMetadata>,
pub safety_ratings: Vec<SafetyRating>,
pub finish_reason: Option<String>,
}
impl Candidate {
pub fn get_text(&self) -> Option<String> {
self.content.get_text()
}
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Citation {
pub start_index: i32,
pub end_index: i32,
pub uri: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct CitationMetadata {
pub citations: Vec<Citation>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct SafetyRating {
pub category: String,
pub probability: String,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct UsageMetadata {
pub candidates_token_count: Option<i32>,
pub prompt_token_count: i32,
pub total_token_count: i32,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionDeclaration {
pub name: String,
pub description: String,
pub parameters: FunctionParameters,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionParameters {
pub r#type: String,
pub properties: HashMap<String, FunctionParametersProperty>,
pub required: Vec<String>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionParametersProperty {
pub r#type: String,
pub description: String,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
#[serde(untagged)]
pub enum GenerateContentResponse {
Ok {
candidates: Vec<Candidate>,
usage_metadata: Option<UsageMetadata>,
},
Error {
error: Error,
},
}

9
src/types/mod.rs Normal file
View File

@@ -0,0 +1,9 @@
mod common;
mod count_tokens;
mod error;
mod generate_content;
pub use common::*;
pub use count_tokens::*;
pub use error::*;
pub use generate_content::*;