diff --git a/examples/text-embedding.rs b/examples/text-embedding.rs new file mode 100644 index 0000000..e362ac4 --- /dev/null +++ b/examples/text-embedding.rs @@ -0,0 +1,40 @@ +use std::sync::Arc; + +use gemini_rs::prelude::*; + +use gcp_auth::AuthenticationManager; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let authentication_manager = Arc::new(AuthenticationManager::new().await?); + let api_endpoint = std::env::var("API_ENDPOINT")?; + let project_id = std::env::var("PROJECT_ID")?; + let location_id = std::env::var("LOCATION_ID")?; + + let gemini = GeminiClient::new( + authentication_manager, + api_endpoint, + project_id, + location_id, + ); + + let embedding_request = TextEmbeddingRequest { + instances: vec![ + TextEmbeddingRequestInstance { + title: String::from("Embed testing"), + content: String::from("Embed testing"), + task_type: String::from("RETRIEVAL_DOCUMENT"), + }, + TextEmbeddingRequestInstance { + title: String::from("Embed testing 2"), + content: String::from("Embed testing 2"), + task_type: String::from("RETRIEVAL_DOCUMENT"), + }, + ], + }; + + let result = gemini.text_embeddings(&embedding_request).await?; + println!("Response: {:?}", result); + + Ok(()) +} diff --git a/src/client.rs b/src/client.rs index 2ea3782..dea9d1d 100644 --- a/src/client.rs +++ b/src/client.rs @@ -8,6 +8,7 @@ use crate::dialogue::{Message, Role}; use crate::error::{Error, Result}; use crate::prelude::{ Candidate, Content, GenerateContentRequest, GenerateContentResponse, GenerationConfig, + TextEmbeddingRequest, TextEmbeddingResponse, }; use crate::{prelude::Part, token_provider::TokenProvider}; @@ -171,4 +172,26 @@ impl GeminiClient { } } } + + pub async fn text_embeddings( + &self, + request: &TextEmbeddingRequest, + ) -> Result { + let model = "textembedding-gecko@003"; + let endpoint_url = format!( + "https://{}/v1/projects/{}/locations/{}/publishers/google/models/{}:predict", + self.api_endpoint, self.project_id, self.location_id, model, + ); + let access_token = self.token_provider.get_token(AUTH_SCOPE).await?; + let resp = self + .client + .post(&endpoint_url) + .bearer_auth(access_token) + .json(&request) + .send() + .await?; + let txt_json = resp.text().await?; + println!("{}", txt_json); + Ok(serde_json::from_str::(&txt_json)?) + } } diff --git a/src/types/mod.rs b/src/types/mod.rs index 406384c..42712fd 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -2,8 +2,10 @@ mod common; mod count_tokens; mod error; mod generate_content; +mod text_embeddings; pub use common::*; pub use count_tokens::*; pub use error::*; pub use generate_content::*; +pub use text_embeddings::*; diff --git a/src/types/text_embeddings.rs b/src/types/text_embeddings.rs new file mode 100644 index 0000000..3f1fb12 --- /dev/null +++ b/src/types/text_embeddings.rs @@ -0,0 +1,43 @@ +use serde::{Deserialize, Serialize}; + +use super::Error; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct TextEmbeddingRequest { + pub instances: Vec, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct TextEmbeddingRequestInstance { + pub content: String, + pub task_type: String, + pub title: String, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(untagged)] +pub enum TextEmbeddingResponse { + Ok { + predictions: Vec, + }, + Error { + error: Error, + }, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct TextEmbeddingPrediction { + pub embeddings: TextEmbeddingResult, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct TextEmbeddingResult { + statistics: TextEmbeddingStatistics, + values: Vec, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct TextEmbeddingStatistics { + truncated: bool, + token_count: u32, +}