api/auth/
guards.rs

1use axum::{extract::{Path, State, FromRequestParts}, http::{Request, StatusCode}, middleware::Next, body::Body, response::{Response, IntoResponse}, Json};
2use util::state::AppState;
3use crate::auth::claims::AuthUser;
4use crate::response::ApiResponse;
5use sea_orm::DatabaseConnection;
6use std::collections::HashMap;
7use sea_orm::EntityTrait;
8use sea_orm::QueryFilter;
9use sea_orm::ColumnTrait;
10use db::models::{
11    module::Entity as ModuleEntity,
12    plagiarism_case::{Entity as PlagiarismEntity, Column as PlagiarismColumn},
13    assignment::{Entity as AssignmentEntity, Column as AssignmentColumn},
14    assignment_task::{Entity as TaskEntity, Column as TaskColumn},
15    assignment_submission::{Entity as SubmissionEntity, Column as SubmissionColumn},
16    assignment_file::{Entity as FileEntity, Column as FileColumn},
17    user::Entity as UserEntity,
18    user
19};
20
21// --- Role Based Access Guards ---
22
23#[derive(serde::Serialize, Default)]
24pub struct Empty;
25
26/// Helper to extract, validate user from request extensions and insert the back into the request
27async fn extract_and_insert_authuser(
28    mut req: Request<Body>
29) -> Result<(Request<Body>, AuthUser), (StatusCode, Json<ApiResponse<Empty>>)> {
30    let (mut parts, body) = req.into_parts();
31    let user = AuthUser::from_request_parts(&mut parts, &())
32        .await
33        .map_err(|_| (
34            StatusCode::UNAUTHORIZED,
35            Json(ApiResponse::error("Authentication required"))
36        ))?;
37    
38    req = Request::from_parts(parts, body);
39    req.extensions_mut().insert(user.clone());
40    Ok((req, user))
41}
42
43/// Helper to check if user has any of the specified roles
44async fn user_has_any_role(
45    db: &DatabaseConnection,
46    user_id: i64,
47    module_id: i64,
48    roles: &[&str],
49) -> bool {
50    for role in roles {
51        if user::Model::is_in_role(db, user_id, module_id, role).await.unwrap_or(false) {
52            return true;
53        }
54    }
55    false
56}
57
58/// Basic guard to ensure the request is authenticated.
59pub async fn require_authenticated(
60    req: Request<Body>,
61    next: Next,
62) -> Result<Response, (StatusCode, Json<ApiResponse<Empty>>)> {
63    let (req, _user) = extract_and_insert_authuser(req).await?;
64
65    Ok(next.run(req).await)
66}
67
68/// Admin-only guard.
69pub async fn require_admin(
70    req: Request<Body>,
71    next: Next,
72) -> Result<Response, (StatusCode, Json<ApiResponse<Empty>>)> {
73    let (req, user) = extract_and_insert_authuser(req).await?;
74    
75    if !user.0.admin {
76        return Err((
77            StatusCode::FORBIDDEN,
78            Json(ApiResponse::error("Admin access required"))
79        ));
80    }
81
82    Ok(next.run(req).await)
83}
84
85/// Base role-based access guard that other guards can build upon
86async fn require_role_base(
87    State(app_state): State<AppState>,
88    Path(params): Path<HashMap<String, String>>,
89    req: Request<Body>,
90    next: Next,
91    required_roles: &[&str],
92    failure_msg: &str,
93) -> Result<Response, (StatusCode, Json<ApiResponse<Empty>>)> {
94    let db: &DatabaseConnection =  app_state.db();
95
96    let (req, user) = extract_and_insert_authuser(req).await?;
97    
98    let module_id = params.get("module_id")
99        .and_then(|s| s.parse::<i64>().ok())
100        .ok_or((
101            StatusCode::BAD_REQUEST,
102            Json(ApiResponse::error("Missing or invalid module_id"))
103        ))?;
104
105    if user.0.admin {
106        return Ok(next.run(req).await);
107    }
108
109    if user_has_any_role(&db, user.0.sub, module_id, required_roles).await {
110        Ok(next.run(req).await)
111    } else {
112        Err((StatusCode::FORBIDDEN, Json(ApiResponse::error(failure_msg))))
113    }
114}
115
116/// Guard for requiring lecturer access.
117pub async fn require_lecturer(
118    State(app_state): State<AppState>,
119    Path(params): Path<HashMap<String, String>>,
120    req: Request<Body>,
121    next: Next,
122) -> Result<Response, (StatusCode, Json<ApiResponse<Empty>>)> {
123    require_role_base(
124        State(app_state),
125        Path(params),
126        req,
127        next,
128        &["Lecturer"],
129        "Lecturer access required for this module"
130    ).await
131}
132
133/// Guard for requiring assistant lecturer access.
134pub async fn require_assistant_lecturer(
135    State(app_state): State<AppState>,
136    Path(params): Path<HashMap<String, String>>,
137    req: Request<Body>,
138    next: Next,
139) -> Result<Response, (StatusCode, Json<ApiResponse<Empty>>)> {
140    require_role_base(
141        State(app_state),
142        Path(params),
143        req,
144        next,
145        &["AssistantLecturer"],
146        "Assistant lecturer access required for this module"
147    ).await
148}
149
150/// Guard for requiring tutor access.
151pub async fn require_tutor(
152    State(app_state): State<AppState>,
153    Path(params): Path<HashMap<String, String>>,
154    req: Request<Body>,
155    next: Next,
156) -> Result<Response, (StatusCode, Json<ApiResponse<Empty>>)> {
157    require_role_base(
158        State(app_state),
159        Path(params),
160        req,
161        next,
162        &["Tutor"],
163        "Tutor access required for this module"
164    ).await
165}
166
167/// Guard for requiring student access.
168pub async fn require_student(
169    State(app_state): State<AppState>,
170    Path(params): Path<HashMap<String, String>>,
171    req: Request<Body>,
172    next: Next,
173) -> Result<Response, (StatusCode, Json<ApiResponse<Empty>>)> {
174    require_role_base(
175        State(app_state),
176        Path(params),
177        req,
178        next,
179        &["Student"],
180        "Student access required for this module"
181    ).await
182}
183
184/// Guard for requiring lecturer or assistant lecturer access.
185pub async fn require_lecturer_or_assistant_lecturer(
186    State(app_state): State<AppState>,
187    Path(params): Path<HashMap<String, String>>,
188    req: Request<Body>,
189    next: Next,
190) -> Result<Response, (StatusCode, Json<ApiResponse<Empty>>)> {
191    require_role_base(
192        State(app_state),
193        Path(params),
194        req,
195        next,
196        &["Lecturer", "AssistantLecturer"],
197        "Lecturer or assistant lecturer access required for this module"
198    ).await
199}
200
201/// Guard for requiring lecturer or tutor access.
202/// TODO: Add ALs to this?
203pub async fn require_lecturer_or_tutor(
204    State(app_state): State<AppState>,
205    Path(params): Path<HashMap<String, String>>,
206    req: Request<Body>,
207    next: Next,
208) -> Result<Response, (StatusCode, Json<ApiResponse<Empty>>)> {
209    require_role_base(
210        State(app_state),
211        Path(params),
212        req,
213        next,
214        &["Lecturer", "Tutor"],
215        "Lecturer or tutor access required for this module"
216    ).await
217}
218
219/// Guard for requiring any assigned role (lecturer, tutor, student).
220pub async fn require_assigned_to_module(
221    State(app_state): State<AppState>,
222    Path(params): Path<HashMap<String, String>>,
223    req: Request<Body>,
224    next: Next,
225) -> Result<Response, (StatusCode, Json<ApiResponse<Empty>>)> {
226    require_role_base(
227        State(app_state),
228        Path(params),
229        req,
230        next,
231        &["Lecturer", "AssistantLecturer", "Tutor", "Student"],
232        "User not assigned to this module"
233    ).await
234}
235
236pub async fn require_ready_assignment(
237    State(app_state): State<AppState>,
238    Path(params): Path<HashMap<String, String>>,
239    req: Request<Body>,
240    next: Next,
241) -> Result<Response, (StatusCode, Json<ApiResponse<Empty>>)> {
242    let db = app_state.db();
243
244    let module_id = params.get("module_id")
245        .and_then(|s| s.parse::<i64>().ok())
246        .ok_or((
247            StatusCode::BAD_REQUEST,
248            Json(ApiResponse::error("Missing or invalid module_id"))
249        ))?;
250
251    let assignment_id = params.get("assignment_id")
252        .and_then(|s| s.parse::<i64>().ok())
253        .ok_or((
254            StatusCode::BAD_REQUEST,
255            Json(ApiResponse::error("Missing or invalid assignment_id"))
256        ))?;
257
258    if let Err(e) = db::models::assignment::Model::try_transition_to_ready(db, module_id, assignment_id).await {
259        return Err((
260            StatusCode::INTERNAL_SERVER_ERROR,
261            Json(ApiResponse::error(format!("Failed to transition assignment to ready: {}", e)))
262        ));
263    }
264
265    let assignment = match AssignmentEntity::find_by_id(assignment_id).one(db).await {
266        Ok(Some(a)) => a,
267        Ok(None) => {
268            return Err((
269                StatusCode::NOT_FOUND,
270                Json(ApiResponse::error(format!(
271                    "Assignment {} in Module {} not found.",
272                    assignment_id, module_id
273                ))),
274            ));
275        }
276        Err(_) => {
277            return Err((
278                StatusCode::INTERNAL_SERVER_ERROR,
279                Json(ApiResponse::error("Database error while checking assignment")),
280            ));
281        }
282    };
283
284    if assignment.status == db::models::assignment::Status::Setup {
285        return Err((
286            StatusCode::FORBIDDEN,
287            Json(ApiResponse::error("Assignment is still in Setup stage"))
288        ));
289    }
290
291    Ok(next.run(req).await)
292}
293
294// --- Path ID Guards ---
295
296async fn check_module_exists(
297    module_id: i32,
298    db: &DatabaseConnection,
299) -> Result<(), (StatusCode, Json<ApiResponse<Empty>>)> {
300    let found = ModuleEntity::find_by_id(module_id)
301        .one(db)
302        .await
303        .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Database error while checking module"))))?;
304
305    if found.is_none() {
306        return Err((
307            StatusCode::NOT_FOUND,
308            Json(ApiResponse::error(format!("Module {} not found.", module_id))),
309        ));
310    }
311    Ok(())
312}
313
314async fn check_user_exists(
315    user_id: i32,
316    db: &DatabaseConnection,
317) -> Result<(), (StatusCode, Json<ApiResponse<Empty>>)> {
318    let found = UserEntity::find_by_id(user_id)
319        .one(db)
320        .await
321        .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Database error while checking user"))))?;
322
323    if found.is_none() {
324        return Err((
325            StatusCode::NOT_FOUND,
326            Json(ApiResponse::error(format!("User {} not found.", user_id))),
327        ));
328    }
329    Ok(())
330}
331
332async fn check_assignment_hierarchy(
333    module_id: i32,
334    assignment_id: i32,
335    db: &DatabaseConnection,
336) -> Result<(), (StatusCode, Json<ApiResponse<Empty>>)> {
337    check_module_exists(module_id, db).await?;
338
339    let found = AssignmentEntity::find()
340        .filter(AssignmentColumn::Id.eq(assignment_id))
341        .filter(AssignmentColumn::ModuleId.eq(module_id))
342        .one(db)
343        .await
344        .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Database error while checking assignment"))))?;
345
346    if found.is_none() {
347        return Err((
348            StatusCode::NOT_FOUND,
349            Json(ApiResponse::error(format!("Assignment {} in Module {} not found.", assignment_id, module_id))),
350        ));
351    }
352    Ok(())
353}
354
355async fn check_task_hierarchy(
356    module_id: i32,
357    assignment_id: i32,
358    task_id: i32,
359    db: &DatabaseConnection,
360) -> Result<(), (StatusCode, Json<ApiResponse<Empty>>)> {
361    check_assignment_hierarchy(module_id, assignment_id, db).await?;
362
363    let found = TaskEntity::find()
364        .filter(TaskColumn::Id.eq(task_id))
365        .filter(TaskColumn::AssignmentId.eq(assignment_id))
366        .one(db)
367        .await
368        .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Database error while checking task"))))?;
369
370    if found.is_none() {
371        return Err((
372            StatusCode::NOT_FOUND,
373            Json(ApiResponse::error(format!("Task {} in Assignment {} not found.", task_id, assignment_id))),
374        ));
375    }
376    Ok(())
377}
378
379async fn check_submission_hierarchy(
380    module_id: i32,
381    assignment_id: i32,
382    submission_id: i32,
383    db: &DatabaseConnection,
384) -> Result<(), (StatusCode, Json<ApiResponse<Empty>>)> {
385    check_assignment_hierarchy(module_id, assignment_id, db).await?;
386
387    let found = SubmissionEntity::find()
388        .filter(SubmissionColumn::Id.eq(submission_id))
389        .filter(SubmissionColumn::AssignmentId.eq(assignment_id))
390        .one(db)
391        .await
392        .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Database error while checking submission"))))?;
393
394    if found.is_none() {
395        return Err((
396            StatusCode::NOT_FOUND,
397            Json(ApiResponse::error(format!("Submission {} in Assignment {} not found.", submission_id, assignment_id))),
398        ));
399    }
400    Ok(())
401}
402
403async fn check_file_hierarchy(
404    module_id: i32,
405    assignment_id: i32,
406    file_id: i32,
407    db: &DatabaseConnection,
408) -> Result<(), (StatusCode, Json<ApiResponse<Empty>>)> {
409    check_assignment_hierarchy(module_id, assignment_id, db).await?;
410
411    let found = FileEntity::find()
412        .filter(FileColumn::Id.eq(file_id))
413        .filter(FileColumn::AssignmentId.eq(assignment_id))
414        .one(db)
415        .await
416        .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Database error while checking file"))))?;
417
418    if found.is_none() {
419        return Err((
420            StatusCode::NOT_FOUND,
421            Json(ApiResponse::error(format!("File {} in Assignment {} not found.", file_id, assignment_id))),
422        ));
423    }
424    Ok(())
425}
426
427pub async fn check_ticket_hierarchy(
428    module_id: i32,
429    assignment_id: i32,
430    ticket_id: i32,
431    db: &DatabaseConnection,
432) -> Result<(), (StatusCode, Json<ApiResponse<Empty>>)> {
433    check_assignment_hierarchy(module_id, assignment_id, db).await?;
434
435    let found = db::models::tickets::Entity::find()
436        .filter(db::models::tickets::Column::Id.eq(ticket_id))
437        .filter(db::models::tickets::Column::AssignmentId.eq(assignment_id))
438        .one(db)
439        .await
440        .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Database error while checking ticket"))))?;
441
442    if found.is_none() {
443        return Err((
444            StatusCode::NOT_FOUND,
445            Json(ApiResponse::error(format!("Ticket {} in Assignment {} not found.", ticket_id, assignment_id))),
446        ));
447    }
448    Ok(())
449}
450
451pub async fn check_message_hierarchy(
452    module_id: i32,
453    assignment_id: i32,
454    ticket_id: i32,
455    message_id: i32,
456    db: &DatabaseConnection,
457) -> Result<(), (StatusCode, Json<ApiResponse<Empty>>)> {
458    check_ticket_hierarchy(module_id, assignment_id, ticket_id, db).await?;
459
460    let found = db::models::ticket_messages::Entity::find()
461        .filter(db::models::ticket_messages::Column::Id.eq(message_id))
462        .filter(db::models::ticket_messages::Column::TicketId.eq(ticket_id))
463        .one(db)
464        .await
465        .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Database error while checking message"))))?;
466
467    if found.is_none() {
468        return Err((
469            StatusCode::NOT_FOUND,
470            Json(ApiResponse::error(format!("Message {} in Ticket {} not found.", message_id, ticket_id))),
471        ));
472    }
473    Ok(())
474}
475
476pub async fn check_plagiarism_hierarchy(
477    module_id: i32,
478    assignment_id: i32,
479    case_id: i32,
480    db: &DatabaseConnection,
481) -> Result<(), (StatusCode, Json<ApiResponse<Empty>>)> {
482    check_assignment_hierarchy(module_id, assignment_id, db).await?;
483
484    let found = PlagiarismEntity::find()
485        .filter(PlagiarismColumn::Id.eq(case_id))
486        .filter(PlagiarismColumn::AssignmentId.eq(assignment_id))
487        .one(db)
488        .await
489        .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Database error while checking plagiarism case"))))?;
490
491    if found.is_none() {
492        return Err((
493            StatusCode::NOT_FOUND,
494            Json(ApiResponse::error(format!("Plagiarism case {} in Assignment {} not found.", case_id, assignment_id))),
495        ));
496    }
497    Ok(())
498}
499
500pub async fn check_announcement_hierarchy(
501    module_id: i32,
502    announcement_id: i32,
503    db: &DatabaseConnection,
504) -> Result<(), (StatusCode, Json<ApiResponse<Empty>>)> {
505    check_module_exists(module_id, db).await?;
506
507    let found = db::models::announcements::Entity::find()
508        .filter(db::models::announcements::Column::Id.eq(announcement_id))
509        .filter(db::models::announcements::Column::ModuleId.eq(module_id))
510        .one(db)
511        .await
512        .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Database error while checking announcement"))))?;
513
514    if found.is_none() {
515        return Err((
516            StatusCode::NOT_FOUND,
517            Json(ApiResponse::error(format!("Announcement {} in Module {} not found.", announcement_id, module_id))),
518        ));
519    }
520    Ok(())
521}
522
523pub async fn validate_known_ids(
524    State(app_state): State<AppState>,
525    Path(params): Path<HashMap<String, String>>,
526    req: Request<Body>,
527    next: Next,
528) -> Result<Response, Response> {
529    let db = app_state.db();
530
531    let mut module_id: Option<i32>     = None;
532    let mut assignment_id: Option<i32> = None;
533    let mut task_id: Option<i32>       = None;
534    let mut submission_id: Option<i32> = None;
535    let mut file_id: Option<i32>       = None;
536    let mut user_id: Option<i32>       = None;
537    let mut ticket_id: Option<i32>     = None;
538    let mut message_id: Option<i32>    = None;
539    let mut case_id: Option<i32>       = None;
540    let mut announcement_id: Option<i32> = None;
541
542    for (key, raw) in &params {
543        let id = raw.parse::<i32>().map_err(|_| {
544            (StatusCode::BAD_REQUEST, Json(ApiResponse::<Empty>::error(format!("Invalid {}: '{}'. Must be an integer.", key, raw)))).into_response()
545        })?;
546        match key.as_str() {
547            "module_id"     => module_id = Some(id),
548            "assignment_id" => assignment_id = Some(id),
549            "task_id"       => task_id = Some(id),
550            "submission_id" => submission_id = Some(id),
551            "file_id"       => file_id = Some(id),
552            "user_id"       => user_id = Some(id),
553            "ticket_id"     => ticket_id = Some(id),
554            "case_id" => case_id = Some(id),
555            "announcement_id" => announcement_id = Some(id),
556            "message_id" => message_id = Some(id),
557            _ => return Err((StatusCode::BAD_REQUEST, Json(ApiResponse::<Empty>::error(format!("Unexpected parameter: '{}'.", key)))).into_response()),
558        }
559    }
560    
561    if let Some(uid) = user_id {
562        check_user_exists(uid, db).await.map_err(|e| e.into_response())?;
563    }
564    if let Some(mid) = module_id {
565        check_module_exists(mid, db).await.map_err(|e| e.into_response())?;
566    }
567    if let (Some(mid), Some(aid)) = (module_id, assignment_id) {
568        check_assignment_hierarchy(mid, aid, db).await.map_err(|e| e.into_response())?;
569    }
570    if let (Some(mid), Some(aid), Some(tid)) = (module_id, assignment_id, task_id) {
571        check_task_hierarchy(mid, aid, tid, db).await.map_err(|e| e.into_response())?;
572    }
573    if let (Some(mid), Some(aid), Some(sid)) = (module_id, assignment_id, submission_id) {
574        check_submission_hierarchy(mid, aid, sid, db).await.map_err(|e| e.into_response())?;
575    }
576    if let (Some(mid), Some(aid), Some(fid)) = (module_id, assignment_id, file_id) {
577        check_file_hierarchy(mid, aid, fid, db).await.map_err(|e| e.into_response())?;
578    }
579    if let (Some(mid), Some(aid), Some(tid)) = (module_id, assignment_id, ticket_id) {
580        check_ticket_hierarchy(mid, aid, tid, db).await.map_err(|e| e.into_response())?;
581    }
582    if let (Some(mid), Some(aid), Some(sid)) = (module_id, assignment_id, case_id) {
583        check_plagiarism_hierarchy(mid, aid, sid, db).await.map_err(|e| e.into_response())?;
584    }
585
586    if let (Some(mid), Some(ann_id)) = (module_id, announcement_id) {
587        check_announcement_hierarchy(mid, ann_id, db).await.map_err(|e| e.into_response())?;
588    }
589
590    if let (Some(mid), Some(aid), Some(tid), Some(meid)) = (module_id, assignment_id, ticket_id, message_id) {
591        check_message_hierarchy(mid, aid, tid, meid, db).await.map_err(|e| e.into_response())?;
592    }
593
594    Ok(next.run(req).await)
595}
596
597// TODO Write tests for this gaurd
598pub async fn require_ticket_ws_access(
599    State(app_state): State<AppState>,
600    Path(params): Path<HashMap<String, String>>,
601    req: axum::http::Request<Body>,
602    next: Next,
603) -> Result<Response, (StatusCode, Json<ApiResponse<Empty>>)> {
604    let db = app_state.db();
605
606    // Must be logged in (also inserts AuthUser into extensions)
607    let (req, user) = extract_and_insert_authuser(req).await?;
608
609    // ticket_id from path
610    let ticket_id = params.get("ticket_id")
611        .and_then(|s| s.parse::<i64>().ok())
612        .ok_or((
613            StatusCode::BAD_REQUEST,
614            Json(ApiResponse::error("Missing or invalid ticket_id")),
615        ))?;
616
617    // Load ticket -> get assignment_id and author
618    let ticket = db::models::tickets::Entity::find_by_id(ticket_id)
619        .one(db)
620        .await
621        .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Database error while checking ticket"))))?
622        .ok_or((StatusCode::NOT_FOUND, Json(ApiResponse::error("Ticket not found"))))?;
623
624    // Author can access
625    if ticket.user_id == user.0.sub {
626        return Ok(next.run(req).await);
627    }
628
629    // Admin can access
630    if user.0.admin {
631        return Ok(next.run(req).await);
632    }
633
634    // Resolve module via assignment -> module_id
635    let assignment = db::models::assignment::Entity::find_by_id(ticket.assignment_id)
636        .one(db)
637        .await
638        .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Database error while checking assignment"))))?
639        .ok_or((StatusCode::NOT_FOUND, Json(ApiResponse::error("Assignment not found for ticket"))))?;
640
641    let module_id = assignment.module_id;
642
643    // Allow module staff (Lecturer, AssistantLecturer, Tutor)
644    if user_has_any_role(db, user.0.sub, module_id, &["Lecturer", "AssistantLecturer", "Tutor"]).await {
645        return Ok(next.run(req).await);
646    }
647
648    // Otherwise, deny
649    Err((StatusCode::FORBIDDEN, Json(ApiResponse::error("Not allowed to access this ticket websocket"))))
650}