Refactor types into separate files
This commit is contained in:
@@ -20,9 +20,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
|
||||
let prompt = "Tell me the story of the genesis of the universe as a bedtime story.";
|
||||
let request = GenerateContentRequest::from_prompt(prompt, None);
|
||||
let queue = gemini
|
||||
.stream_generate_content(&request, Model::GeminiPro)
|
||||
.await;
|
||||
let queue = gemini.stream_generate_content(&request, "gemini-pro").await;
|
||||
|
||||
while let Some(response) = queue.pop().await {
|
||||
if let GenerateContentResponse::Ok {
|
||||
|
||||
@@ -13,18 +13,6 @@ use crate::{prelude::Part, token_provider::TokenProvider};
|
||||
|
||||
pub static AUTH_SCOPE: &[&str] = &["https://www.googleapis.com/auth/cloud-platform"];
|
||||
|
||||
pub enum Model {
|
||||
GeminiPro,
|
||||
}
|
||||
|
||||
impl ToString for Model {
|
||||
fn to_string(&self) -> String {
|
||||
match self {
|
||||
Model::GeminiPro => "gemini-pro".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct GeminiClient<T: TokenProvider + Clone> {
|
||||
token_provider: T,
|
||||
@@ -56,7 +44,7 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
|
||||
pub async fn stream_generate_content(
|
||||
&self,
|
||||
request: &GenerateContentRequest,
|
||||
model: Model,
|
||||
model: &str,
|
||||
) -> Arc<Queue<Option<GenerateContentResponse>>> {
|
||||
let queue = Arc::new(Queue::<Option<GenerateContentResponse>>::new());
|
||||
|
||||
@@ -93,7 +81,7 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
|
||||
pub async fn generate_content(
|
||||
&self,
|
||||
request: &GenerateContentRequest,
|
||||
model: Model,
|
||||
model: &str,
|
||||
) -> Result<GenerateContentResponse> {
|
||||
let access_token = self.token_provider.get_token(AUTH_SCOPE).await?;
|
||||
let endpoint_url: String = format!(
|
||||
@@ -131,7 +119,7 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
|
||||
tools: None,
|
||||
};
|
||||
|
||||
let response = self.generate_content(&request, Model::GeminiPro).await?;
|
||||
let response = self.generate_content(&request, "gemini-pro").await?;
|
||||
|
||||
// Check for errors in the response.
|
||||
let mut candidates = GeminiClient::<T>::collect_text_from_response(&response)?;
|
||||
@@ -158,7 +146,7 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
|
||||
tools: None,
|
||||
};
|
||||
|
||||
let response = self.generate_content(&request, Model::GeminiPro).await?;
|
||||
let response = self.generate_content(&request, "gemini-pro").await?;
|
||||
let mut candidates = GeminiClient::<T>::collect_text_from_response(&response)?;
|
||||
|
||||
match candidates.pop() {
|
||||
|
||||
41
src/types/common.rs
Normal file
41
src/types/common.rs
Normal file
@@ -0,0 +1,41 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Content {
|
||||
pub role: String,
|
||||
pub parts: Option<Vec<Part>>,
|
||||
}
|
||||
|
||||
impl Content {
|
||||
pub fn get_text(&self) -> Option<String> {
|
||||
self.parts.as_ref().map(|parts| {
|
||||
parts
|
||||
.iter()
|
||||
.filter_map(|part| match part {
|
||||
Part::Text(text) => Some(text.clone()),
|
||||
_ => None,
|
||||
})
|
||||
.collect::<String>()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub enum Part {
|
||||
Text(String),
|
||||
InlineData {
|
||||
mime_type: String,
|
||||
data: String,
|
||||
},
|
||||
FileData {
|
||||
mime_type: String,
|
||||
file_uri: String,
|
||||
},
|
||||
FunctionCall {
|
||||
name: String,
|
||||
args: HashMap<String, String>,
|
||||
},
|
||||
}
|
||||
14
src/types/count_tokens.rs
Normal file
14
src/types/count_tokens.rs
Normal file
@@ -0,0 +1,14 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::Content;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct CountTokensRequest {
|
||||
pub contents: Content,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CountTokensResponse {
|
||||
pub total_tokens: i32,
|
||||
}
|
||||
31
src/types/error.rs
Normal file
31
src/types/error.rs
Normal file
@@ -0,0 +1,31 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Error {
|
||||
pub code: i32,
|
||||
pub message: String,
|
||||
pub status: String,
|
||||
pub details: Vec<ErrorType>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Link {
|
||||
pub description: String,
|
||||
pub url: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "@type")]
|
||||
pub enum ErrorType {
|
||||
#[serde(rename = "type.googleapis.com/google.rpc.ErrorInfo")]
|
||||
ErrorInfo { metadata: ErrorInfoMetadata },
|
||||
|
||||
#[serde(rename = "type.googleapis.com/google.rpc.Help")]
|
||||
Help { links: Vec<Link> },
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ErrorInfoMetadata {
|
||||
service: String,
|
||||
consumer: String,
|
||||
}
|
||||
@@ -1,196 +1,119 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct CountTokensRequest {
|
||||
pub contents: Content,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CountTokensResponse {
|
||||
pub total_tokens: i32,
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
pub struct GenerateContentRequest {
|
||||
pub contents: Vec<Content>,
|
||||
pub generation_config: Option<GenerationConfig>,
|
||||
pub tools: Option<Vec<Tools>>,
|
||||
}
|
||||
|
||||
impl GenerateContentRequest {
|
||||
pub fn from_prompt(prompt: &str, generation_config: Option<GenerationConfig>) -> Self {
|
||||
GenerateContentRequest {
|
||||
contents: vec![Content {
|
||||
role: "user".to_string(),
|
||||
parts: Some(vec![Part::Text(prompt.to_string())]),
|
||||
}],
|
||||
generation_config,
|
||||
tools: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
pub struct Tools {
|
||||
pub function_declarations: Option<Vec<FunctionDeclaration>>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Content {
|
||||
pub role: String,
|
||||
pub parts: Option<Vec<Part>>,
|
||||
}
|
||||
|
||||
impl Content {
|
||||
pub fn get_text(&self) -> Option<String> {
|
||||
self.parts.as_ref().map(|parts| {
|
||||
parts
|
||||
.iter()
|
||||
.filter_map(|part| match part {
|
||||
Part::Text(text) => Some(text.clone()),
|
||||
_ => None,
|
||||
})
|
||||
.collect::<String>()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct GenerationConfig {
|
||||
pub max_output_tokens: Option<i32>,
|
||||
pub temperature: Option<f32>,
|
||||
pub top_p: Option<f32>,
|
||||
pub top_k: Option<i32>,
|
||||
pub stop_sequences: Option<Vec<String>>,
|
||||
pub candidate_count: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub enum Part {
|
||||
Text(String),
|
||||
InlineData {
|
||||
mime_type: String,
|
||||
data: String,
|
||||
},
|
||||
FileData {
|
||||
mime_type: String,
|
||||
file_uri: String,
|
||||
},
|
||||
FunctionCall {
|
||||
name: String,
|
||||
args: HashMap<String, String>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[serde(untagged)]
|
||||
pub enum GenerateContentResponse {
|
||||
Ok {
|
||||
candidates: Vec<Candidate>,
|
||||
usage_metadata: Option<UsageMetadata>,
|
||||
},
|
||||
Error {
|
||||
error: Error,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Candidate {
|
||||
pub content: Content,
|
||||
pub citation_metadata: Option<CitationMetadata>,
|
||||
pub safety_ratings: Vec<SafetyRating>,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
impl Candidate {
|
||||
pub fn get_text(&self) -> Option<String> {
|
||||
self.content.get_text()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct SafetyRating {
|
||||
pub category: String,
|
||||
pub probability: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Citation {
|
||||
pub start_index: i32,
|
||||
pub end_index: i32,
|
||||
pub uri: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct CitationMetadata {
|
||||
pub citations: Vec<Citation>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct UsageMetadata {
|
||||
pub candidates_token_count: Option<i32>,
|
||||
pub prompt_token_count: i32,
|
||||
pub total_token_count: i32,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct FunctionDeclaration {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub parameters: FunctionParameters,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct FunctionParameters {
|
||||
pub r#type: String,
|
||||
pub properties: HashMap<String, FunctionParametersProperty>,
|
||||
pub required: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct FunctionParametersProperty {
|
||||
pub r#type: String,
|
||||
pub description: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Error {
|
||||
pub code: i32,
|
||||
pub message: String,
|
||||
pub status: String,
|
||||
pub details: Vec<ErrorType>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Link {
|
||||
pub description: String,
|
||||
pub url: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "@type")]
|
||||
pub enum ErrorType {
|
||||
#[serde(rename = "type.googleapis.com/google.rpc.ErrorInfo")]
|
||||
ErrorInfo { metadata: ErrorInfoMetadata },
|
||||
|
||||
#[serde(rename = "type.googleapis.com/google.rpc.Help")]
|
||||
Help { links: Vec<Link> },
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ErrorInfoMetadata {
|
||||
service: String,
|
||||
consumer: String,
|
||||
}
|
||||
use std::collections::HashMap;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::{Content, Error, Part};
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
pub struct GenerateContentRequest {
|
||||
pub contents: Vec<Content>,
|
||||
pub generation_config: Option<GenerationConfig>,
|
||||
pub tools: Option<Vec<Tools>>,
|
||||
}
|
||||
|
||||
impl GenerateContentRequest {
|
||||
pub fn from_prompt(prompt: &str, generation_config: Option<GenerationConfig>) -> Self {
|
||||
GenerateContentRequest {
|
||||
contents: vec![Content {
|
||||
role: "user".to_string(),
|
||||
parts: Some(vec![Part::Text(prompt.to_string())]),
|
||||
}],
|
||||
generation_config,
|
||||
tools: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
pub struct Tools {
|
||||
pub function_declarations: Option<Vec<FunctionDeclaration>>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct GenerationConfig {
|
||||
pub max_output_tokens: Option<i32>,
|
||||
pub temperature: Option<f32>,
|
||||
pub top_p: Option<f32>,
|
||||
pub top_k: Option<i32>,
|
||||
pub stop_sequences: Option<Vec<String>>,
|
||||
pub candidate_count: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Candidate {
|
||||
pub content: Content,
|
||||
pub citation_metadata: Option<CitationMetadata>,
|
||||
pub safety_ratings: Vec<SafetyRating>,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
impl Candidate {
|
||||
pub fn get_text(&self) -> Option<String> {
|
||||
self.content.get_text()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Citation {
|
||||
pub start_index: i32,
|
||||
pub end_index: i32,
|
||||
pub uri: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct CitationMetadata {
|
||||
pub citations: Vec<Citation>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct SafetyRating {
|
||||
pub category: String,
|
||||
pub probability: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct UsageMetadata {
|
||||
pub candidates_token_count: Option<i32>,
|
||||
pub prompt_token_count: i32,
|
||||
pub total_token_count: i32,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct FunctionDeclaration {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub parameters: FunctionParameters,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct FunctionParameters {
|
||||
pub r#type: String,
|
||||
pub properties: HashMap<String, FunctionParametersProperty>,
|
||||
pub required: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct FunctionParametersProperty {
|
||||
pub r#type: String,
|
||||
pub description: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[serde(untagged)]
|
||||
pub enum GenerateContentResponse {
|
||||
Ok {
|
||||
candidates: Vec<Candidate>,
|
||||
usage_metadata: Option<UsageMetadata>,
|
||||
},
|
||||
Error {
|
||||
error: Error,
|
||||
},
|
||||
}
|
||||
9
src/types/mod.rs
Normal file
9
src/types/mod.rs
Normal file
@@ -0,0 +1,9 @@
|
||||
mod common;
|
||||
mod count_tokens;
|
||||
mod error;
|
||||
mod generate_content;
|
||||
|
||||
pub use common::*;
|
||||
pub use count_tokens::*;
|
||||
pub use error::*;
|
||||
pub use generate_content::*;
|
||||
Reference in New Issue
Block a user