Adds a count_token methods to the Gemini Clent
This commit is contained in:
32
examples/count-tokens.rs
Normal file
32
examples/count-tokens.rs
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use gemini_rs::prelude::*;
|
||||||
|
|
||||||
|
use gcp_auth::AuthenticationManager;
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
let authentication_manager = Arc::new(AuthenticationManager::new().await?);
|
||||||
|
let api_endpoint = std::env::var("API_ENDPOINT")?;
|
||||||
|
let project_id = std::env::var("PROJECT_ID")?;
|
||||||
|
let location_id = std::env::var("LOCATION_ID")?;
|
||||||
|
|
||||||
|
let gemini = GeminiClient::new(
|
||||||
|
authentication_manager,
|
||||||
|
api_endpoint,
|
||||||
|
project_id,
|
||||||
|
location_id,
|
||||||
|
);
|
||||||
|
|
||||||
|
let prompt = "What is the airspeed of an unladen swallow?";
|
||||||
|
let request = CountTokensRequest {
|
||||||
|
contents: Content {
|
||||||
|
role: "user".to_string(),
|
||||||
|
parts: Some(vec![Part::Text(prompt.to_string())]),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
let result = gemini.count_tokens(&request, "gemini-pro").await?;
|
||||||
|
println!("Response: {:?}", result);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
@@ -7,8 +7,8 @@ use reqwest_eventsource::{Event, EventSource};
|
|||||||
use crate::dialogue::{Message, Role};
|
use crate::dialogue::{Message, Role};
|
||||||
use crate::error::{Error, Result};
|
use crate::error::{Error, Result};
|
||||||
use crate::prelude::{
|
use crate::prelude::{
|
||||||
Candidate, Content, GenerateContentRequest, GenerateContentResponse, GenerationConfig,
|
Candidate, Content, CountTokensRequest, CountTokensResponse, GenerateContentRequest,
|
||||||
TextEmbeddingRequest, TextEmbeddingResponse,
|
GenerateContentResponse, GenerationConfig, TextEmbeddingRequest, TextEmbeddingResponse,
|
||||||
};
|
};
|
||||||
use crate::{prelude::Part, token_provider::TokenProvider};
|
use crate::{prelude::Part, token_provider::TokenProvider};
|
||||||
|
|
||||||
@@ -191,6 +191,30 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
|
|||||||
.send()
|
.send()
|
||||||
.await?;
|
.await?;
|
||||||
let txt_json = resp.text().await?;
|
let txt_json = resp.text().await?;
|
||||||
|
tracing::debug!("text_embeddings response: {:?}", txt_json);
|
||||||
Ok(serde_json::from_str::<TextEmbeddingResponse>(&txt_json)?)
|
Ok(serde_json::from_str::<TextEmbeddingResponse>(&txt_json)?)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn count_tokens(
|
||||||
|
&self,
|
||||||
|
request: &CountTokensRequest,
|
||||||
|
model: &str,
|
||||||
|
) -> Result<CountTokensResponse> {
|
||||||
|
let endpoint_url = format!(
|
||||||
|
"https://{}/v1beta1/projects/{}/locations/{}/publishers/google/models/{}:countTokens",
|
||||||
|
self.api_endpoint, self.project_id, self.location_id, model,
|
||||||
|
);
|
||||||
|
let access_token = self.token_provider.get_token(AUTH_SCOPE).await?;
|
||||||
|
let resp = self
|
||||||
|
.client
|
||||||
|
.post(&endpoint_url)
|
||||||
|
.bearer_auth(access_token)
|
||||||
|
.json(&request)
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let txt_json = resp.text().await?;
|
||||||
|
tracing::debug!("count_tokens response: {:?}", txt_json);
|
||||||
|
Ok(serde_json::from_str(&txt_json)?)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,7 +8,14 @@ pub struct CountTokensRequest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(untagged)]
|
||||||
pub struct CountTokensResponse {
|
pub enum CountTokensResponse {
|
||||||
pub total_tokens: i32,
|
#[serde(rename_all = "camelCase")]
|
||||||
|
Ok {
|
||||||
|
total_tokens: i32,
|
||||||
|
total_billable_characters: u32,
|
||||||
|
},
|
||||||
|
Error {
|
||||||
|
error: super::Error,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user