Add the text_embeddings API

This commit is contained in:
2024-02-28 14:03:30 +00:00
parent b8f2b7e85e
commit 627ce368b4
4 changed files with 108 additions and 0 deletions

View File

@@ -0,0 +1,40 @@
use std::sync::Arc;
use gemini_rs::prelude::*;
use gcp_auth::AuthenticationManager;
/// Example: request text embeddings for two short sample documents and print
/// the response.
///
/// Reads `API_ENDPOINT`, `PROJECT_ID` and `LOCATION_ID` from the environment;
/// any missing variable, authentication failure, or request failure is
/// returned as the error of `main`.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Authentication is set up first, before the environment is consulted.
    let auth = AuthenticationManager::new().await?;
    let endpoint = std::env::var("API_ENDPOINT")?;
    let project = std::env::var("PROJECT_ID")?;
    let location = std::env::var("LOCATION_ID")?;
    let gemini = GeminiClient::new(Arc::new(auth), endpoint, project, location);

    // Each sample text is used as both the title and the content of its instance.
    let instances = ["Embed testing", "Embed testing 2"]
        .iter()
        .map(|&text| TextEmbeddingRequestInstance {
            title: String::from(text),
            content: String::from(text),
            task_type: String::from("RETRIEVAL_DOCUMENT"),
        })
        .collect();
    let embedding_request = TextEmbeddingRequest { instances };

    let response = gemini.text_embeddings(&embedding_request).await?;
    println!("Response: {:?}", response);
    Ok(())
}

View File

@@ -8,6 +8,7 @@ use crate::dialogue::{Message, Role};
use crate::error::{Error, Result}; use crate::error::{Error, Result};
use crate::prelude::{ use crate::prelude::{
Candidate, Content, GenerateContentRequest, GenerateContentResponse, GenerationConfig, Candidate, Content, GenerateContentRequest, GenerateContentResponse, GenerationConfig,
TextEmbeddingRequest, TextEmbeddingResponse,
}; };
use crate::{prelude::Part, token_provider::TokenProvider}; use crate::{prelude::Part, token_provider::TokenProvider};
@@ -171,4 +172,26 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
} }
} }
} }
/// Requests text embeddings for every instance in `request`.
///
/// The request is serialized as JSON and POSTed to the `:predict` endpoint of
/// the `textembedding-gecko@003` publisher model, authenticated with a bearer
/// token obtained from the client's token provider for `AUTH_SCOPE`.
///
/// # Errors
///
/// Returns an error if the access token cannot be obtained, if sending the
/// request or reading the response body fails, or if the body cannot be
/// deserialized into a [`TextEmbeddingResponse`].
///
/// The HTTP status code is not inspected here; the body is parsed as-is, so
/// callers should match on the returned [`TextEmbeddingResponse`] to
/// distinguish predictions from an API error payload.
pub async fn text_embeddings(
    &self,
    request: &TextEmbeddingRequest,
) -> Result<TextEmbeddingResponse> {
    let model = "textembedding-gecko@003";
    let endpoint_url = format!(
        "https://{}/v1/projects/{}/locations/{}/publishers/google/models/{}:predict",
        self.api_endpoint, self.project_id, self.location_id, model,
    );
    let access_token = self.token_provider.get_token(AUTH_SCOPE).await?;
    let resp = self
        .client
        .post(&endpoint_url)
        .bearer_auth(access_token)
        // `request` is already a reference; no extra borrow is needed.
        .json(request)
        .send()
        .await?;
    // The raw body is intentionally not printed: library code should not
    // write response payloads to stdout.
    let body = resp.text().await?;
    Ok(serde_json::from_str::<TextEmbeddingResponse>(&body)?)
}
} }

View File

@@ -2,8 +2,10 @@ mod common;
mod count_tokens; mod count_tokens;
mod error; mod error;
mod generate_content; mod generate_content;
mod text_embeddings;
pub use common::*; pub use common::*;
pub use count_tokens::*; pub use count_tokens::*;
pub use error::*; pub use error::*;
pub use generate_content::*; pub use generate_content::*;
pub use text_embeddings::*;

View File

@@ -0,0 +1,43 @@
use serde::{Deserialize, Serialize};
use super::Error;
/// Request body for a text-embedding call: a batch of instances to embed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextEmbeddingRequest {
    /// The individual texts to embed, in request order.
    pub instances: Vec<TextEmbeddingRequestInstance>,
}
/// A single entry of a [`TextEmbeddingRequest`].
///
/// Field order is significant: it is the order in which the keys are
/// serialized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextEmbeddingRequestInstance {
    /// The text to embed.
    pub content: String,
    /// Task type sent to the service as a plain string, e.g.
    /// `"RETRIEVAL_DOCUMENT"`.
    pub task_type: String,
    /// Title accompanying the content.
    // NOTE(review): the service may only accept a title for some task
    // types — confirm against the API reference.
    pub title: String,
}
/// Response of a text-embedding call.
///
/// Deserialized without a tag: a body carrying a `predictions` key becomes
/// [`TextEmbeddingResponse::Ok`], otherwise a body carrying an `error` key
/// becomes [`TextEmbeddingResponse::Error`]. Variants are tried in
/// declaration order, so the order below must be preserved.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum TextEmbeddingResponse {
    /// Successful response holding the returned predictions.
    Ok {
        predictions: Vec<TextEmbeddingPrediction>,
    },
    /// Error payload returned in place of predictions.
    Error {
        error: Error,
    },
}
/// One element of the `predictions` list in a successful
/// [`TextEmbeddingResponse`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextEmbeddingPrediction {
    /// The embedding result carried by this prediction.
    pub embeddings: TextEmbeddingResult,
}
/// Embedding payload of a single prediction.
///
/// The fields are public so that callers outside this module can actually
/// read the returned vector and its statistics; with private fields the
/// data was only reachable through `Debug` output or re-serialization.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TextEmbeddingResult {
    /// Statistics reported alongside the embedding.
    pub statistics: TextEmbeddingStatistics,
    /// The numeric values of the embedding.
    pub values: Vec<f32>,
}
/// Statistics attached to a [`TextEmbeddingResult`].
///
/// The fields are public so that callers outside this module can inspect
/// them; they were previously private and therefore unreadable by users of
/// the crate.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TextEmbeddingStatistics {
    /// Truncation flag reported by the service.
    // NOTE(review): presumably true when the input text was cut to fit the
    // model's input limit — confirm against the API reference.
    pub truncated: bool,
    /// Token count reported by the service.
    // NOTE(review): presumably the number of tokens in the input text —
    // confirm against the API reference.
    pub token_count: u32,
}